GokseninYuksel committed
Commit f0e612b · verified · 1 Parent(s): 83c8468

Upload model

Files changed (12)
  1. Patcher.py +103 -0
  2. config.json +35 -0
  3. configuration_gramt_binaural_time.py +49 -0
  4. droppath.py +41 -0
  5. model.py +309 -0
  6. model.safetensors +3 -0
  7. modeling_gramt_binaural_time.py +41 -0
  8. mwmae.py +434 -0
  9. patching_utils.py +126 -0
  10. pos_embed.py +210 -0
  11. swin.py +522 -0
  12. utils.py +249 -0
Patcher.py ADDED
@@ -0,0 +1,103 @@
from abc import ABC

from .patching_utils import combine_patches, generate_patches, get_shape


class PatchStrategy(ABC):
    def __init__(self, tstride, tshape, fstride, fshape, input_fdim, input_tdim):
        self.tstride = tstride
        self.tshape = tshape
        self.fstride = fstride
        self.fshape = fshape
        self.input_fdim = input_fdim
        self.input_tdim = input_tdim

    def _patch(self, x):
        patches = generate_patches(
            input=x,
            fstride=self.fstride,
            tstride=self.tstride,
            fshape=self.fshape,
            tshape=self.tshape,
        )
        return patches

    def patch(self, x):
        return self._patch(x)

    def embed(self, x, patch_embedder):
        return patch_embedder(x)

    def patch_and_embed(self, x, patch_embedder):
        """
        Generate patches from the input spectrogram and embed them.

        This method creates patches based on the frequency and temporal stride/shape
        parameters, and then applies the given patch embedding function.

        Parameters
        ----------
        x : torch.Tensor
            The input spectrogram tensor to be patched and embedded.
        patch_embedder : Callable
            A function that applies embedding to the patches.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The generated patches and their embeddings.
        """
        # Generate patches of the raw input so they can serve as reconstruction targets.
        patches = generate_patches(
            input=x,
            fstride=self.fstride,
            tstride=self.tstride,
            fshape=self.fshape,
            tshape=self.tshape,
        )
        x = patch_embedder(x)
        return patches, x

    def get_patch_size(self):
        p_f_dim, p_t_dim = get_shape(
            fstride=self.fstride,
            tstride=self.tstride,
            input_fdim=self.input_fdim,
            input_tdim=self.input_tdim,
            fshape=self.fshape,
            tshape=self.tshape,
        )
        return p_f_dim, p_t_dim

    def combine_patches(self, patches, original_size):
        return combine_patches(
            patches, original_size, self.fstride, self.tstride, self.fshape, self.tshape
        )


class TimePatching(PatchStrategy):
    def __init__(
        self, input_tdim, tstride=2, tshape=2, fstride=128, fshape=128, input_fdim=128
    ):
        super().__init__(
            tstride=tstride,
            tshape=tshape,
            fstride=fstride,
            fshape=fshape,
            input_fdim=input_fdim,
            input_tdim=input_tdim,
        )


class FramePatching(PatchStrategy):
    def __init__(
        self, input_tdim, tstride=16, tshape=16, fstride=16, fshape=16, input_fdim=128
    ):
        super().__init__(
            tstride=tstride,
            tshape=tshape,
            fstride=fstride,
            fshape=fshape,
            input_fdim=input_fdim,
            input_tdim=input_tdim,
        )
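For orientation, a minimal sketch (not part of the upload, assuming the repository directory is on sys.path) of the patch-grid arithmetic that PatchStrategy.get_patch_size performs for the default binaural time patching:

import torch
from patching_utils import get_shape  # same helper Patcher.py uses internally

# 128-bin x 2-frame patches with stride (128, 2) over a 128-mel, 200-frame spectrogram.
f_dim, t_dim = get_shape(fstride=128, tstride=2, input_fdim=128, input_tdim=200,
                         fshape=128, tshape=2)
print(f_dim, t_dim)        # 1 100
print(f_dim * t_dim)       # 100 patches -> GRAMT.num_patches for this configuration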
config.json ADDED
@@ -0,0 +1,35 @@
{
  "architectures": [
    "GRAMTBinauralTimeModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_gramt_binaural_time.GRAMTBinauralTimeConfig",
    "AutoModel": "modeling_gramt_binaural_time.GRAMTBinauralTimeModel"
  },
  "decoder_depth": 8,
  "decoder_embedding_dim": 512,
  "decoder_mlp_ratio": 4.0,
  "decoder_num_heads": 8,
  "decoder_window_sizes": [
    2,
    5,
    10,
    25,
    50,
    0,
    0,
    0
  ],
  "encoder_attention_dropout": 0.0,
  "encoder_dropout": 0.0,
  "encoder_hidden_dim": 768,
  "encoder_mlp_ratio": 4.0,
  "encoder_norm_layer_eps": 1e-06,
  "encoder_num_heads": 12,
  "encoder_num_layers": 12,
  "input_length": 200,
  "model_type": "gramt-binaural-time",
  "num_mel_bins": 128,
  "torch_dtype": "float32",
  "transformers_version": "4.46.3"
}
configuration_gramt_binaural_time.py ADDED
@@ -0,0 +1,49 @@
from transformers import PretrainedConfig
from typing import List


class GRAMTBinauralTimeConfig(PretrainedConfig):
    model_type = "gramt-binaural-time"
    model_size = "base"
    in_channels: int = 2
    patch_size = (128, 2)
    frequency_stride = 128
    time_stride = 2

    def __init__(
        self,
        decoder_mlp_ratio: float = 4.0,
        decoder_depth: int = 8,
        decoder_num_heads: int = 8,
        decoder_embedding_dim: int = 512,
        decoder_window_sizes: List[int] = [2, 5, 10, 25, 50, 0, 0, 0],
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_hidden_dim=768,
        encoder_mlp_ratio=4.0,
        encoder_dropout=0.0,
        encoder_attention_dropout=0.0,
        encoder_norm_layer_eps=1e-6,
        input_length=200,
        num_mel_bins=128,
        **kwargs,
    ):
        self.decoder_mlp_ratio = decoder_mlp_ratio
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.decoder_embedding_dim = decoder_embedding_dim
        self.decoder_window_sizes = decoder_window_sizes

        self.encoder_num_layers = encoder_num_layers
        self.encoder_num_heads = encoder_num_heads
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_mlp_ratio = encoder_mlp_ratio
        self.encoder_dropout = encoder_dropout
        self.encoder_attention_dropout = encoder_attention_dropout
        self.encoder_norm_layer_eps = encoder_norm_layer_eps

        self.input_length = input_length
        self.num_mel_bins = num_mel_bins
        super().__init__(**kwargs)
droppath.py ADDED
@@ -0,0 +1,41 @@
"""
Implementation of DropPath (Stochastic Depth) regularization.

Inspired by the PyTorch implementation in timm (https://github.com/rwightman/pytorch-image-models)
by Ross Wightman, 2022
"""

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x, training=True):
        return drop_path(x, self.drop_prob, training)
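A quick behavioural sketch of the scaling above (illustrative only, assuming the repository directory is on sys.path): each sample's residual branch is zeroed with probability drop_prob or scaled by 1/keep_prob otherwise, so the expectation is preserved, and the module is a no-op outside training.

import torch
from droppath import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4, 16)

out_eval = dp(x, training=False)
print(torch.equal(out_eval, x))                       # True: identity when not training

out_train = dp(x, training=True)
# Each sample is either all zeros or scaled by 1 / keep_prob = 2.0
print(sorted(out_train[:, 0, 0].unique().tolist()))   # a subset of [0.0, 2.0]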
model.py ADDED
@@ -0,0 +1,309 @@
import torch
from torch import nn


from .Patcher import PatchStrategy
from .mwmae import MWMHABlock
from .pos_embed import get_2d_sincos_pos_embed
from .utils import PatchEmbed, create_pretrained_model, repeat_token

from einops import rearrange


def conv3x3(in_channels, out_channels, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )


class GRAMT(nn.Module):
    def __init__(
        self,
        model_size="base",
        in_channels=2,
        decoder_mlp_ratio: float = 4.0,
        decoder_depth: int = 8,
        decoder_num_heads: int = 8,
        decoder_embedding_dim: int = 512,
        decoder_window_sizes: list[int] = [2, 5, 10, 25, 50, 100, 0, 0],
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_hidden_dim=768,
        encoder_mlp_ratio=4.0,
        encoder_dropout=0.0,
        encoder_attention_dropout=0.0,
        encoder_norm_layer_eps=1e-6,
        patch_size=(16, 8),
        frequency_stride=16,
        time_stride=8,
        input_length=200,
        num_mel_bins=128,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.input_length = input_length
        # Calculate intermediate shape after masking
        self.patch_strategy = PatchStrategy(
            tstride=time_stride,
            tshape=patch_size[1],
            fstride=frequency_stride,
            fshape=patch_size[0],
            input_fdim=num_mel_bins,
            input_tdim=self.input_length,
        )
        self.p_f_dim, self.p_t_dim = self.patch_strategy.get_patch_size()
        self.num_patches = self.p_f_dim * self.p_t_dim
        self.grid_size = (self.p_f_dim, self.p_t_dim)

        # This is our encoder.
        # --------------------------------------------------------------------------

        # Transformer
        (
            self.encoder,
            self.encoder_embedding_dim,
        ) = create_pretrained_model(
            model_size,
            encoder_num_layers=encoder_num_layers,
            encoder_num_heads=encoder_num_heads,
            encoder_hidden_dim=encoder_hidden_dim,
            encoder_mlp_dim=int(encoder_hidden_dim * encoder_mlp_ratio),
            encoder_dropout=encoder_dropout,
            encoder_attention_dropout=encoder_attention_dropout,
            encoder_norm_layer_eps=encoder_norm_layer_eps,
        )
        self.encoder_cls_token_num = 1

        # Patch Embedder
        self.patch_embed = PatchEmbed()
        self._update_patch_embed_layers(self.patch_embed)

        # Norm/Pos
        self.register_buffer(
            "cls_token",
            nn.Parameter(
                torch.zeros([1, 1, self.encoder_embedding_dim]), requires_grad=True
            ),
        )
        torch.nn.init.normal_(self.cls_token, std=0.02)

        # This is our decoder.
        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.decoder_embedding_dim = decoder_embedding_dim
        self.decoder_window_sizes = decoder_window_sizes
        self.decoder_embed = nn.Linear(
            self.encoder_embedding_dim, self.decoder_embedding_dim, bias=True
        )

        self.register_buffer(
            "mask_token",
            nn.Parameter(
                torch.zeros(1, 1, self.decoder_embedding_dim, requires_grad=True)
            ),
        )
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.decoder_blocks = nn.ModuleList(
            [
                MWMHABlock(
                    dim=decoder_embedding_dim,
                    num_heads=decoder_num_heads,
                    window_sizes=decoder_window_sizes,
                    shift_windows=False,
                    mlp_ratio=decoder_mlp_ratio,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                )
                for i in range(self.decoder_depth)
            ]
        )
        # The decoder uses multi-window attention, so the CLS token is dropped in the decoder
        # (see pass_through_decoder, which reads this flag).
        self.use_mwmae_decoder = True
        cls_token_num = 0
        self.encoder.pos_embedding = self._get_pos_embed_params()
        # Pos Embed init w/o the cls token num
        self.register_buffer(
            "decoder_pos_embed",
            nn.Parameter(
                torch.zeros(1, self.num_patches, decoder_embedding_dim),
                requires_grad=False,
            ),
        )
        pos_embed = get_2d_sincos_pos_embed(
            decoder_embedding_dim, self.grid_size, cls_token_num=cls_token_num
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0)
        )
        # Define prediction layers for Masked Auto Encoder pretraining
        self.spec_pred = nn.Sequential(
            nn.Linear(
                decoder_embedding_dim,
                self.patch_strategy.fshape
                * self.patch_strategy.tshape
                * self.in_channels,
                bias=True,
            ),
        )
        self.decoder_norm = nn.LayerNorm(decoder_embedding_dim)
        # Normalize binaural/ambisonic spectrograms with Layer norm later.
        self.spectrogram_normalize = nn.LayerNorm(
            [self.in_channels, num_mel_bins, self.input_length],
            elementwise_affine=False,
        )
        self.input_shape = [num_mel_bins, self.input_length]
        compile_modules = kwargs.get("compile_modules", None)
        if (compile_modules is not None) and (compile_modules):
            self._compile_operations()

    def _compile_operations(self):
        """
        Use torch.compile on the extractor, encoder and decoder blocks for a faster forward pass.
        """
        try:
            self.forward = torch.compile(
                self.get_audio_representation, mode="reduce-overhead"
            )
        except Exception as e:
            print(f"Warning: Could not compile operations: {e}")
            self.use_compiled_forward = False

    def _get_pos_embed_params(self):
        """Calculates the positional embedding parameters and returns them."""
        # Update positional embedding
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                self.num_patches + self.encoder_cls_token_num,
                self.encoder_embedding_dim,
            ),
            requires_grad=False,
        )
        pos_embed_data = get_2d_sincos_pos_embed(
            self.encoder_embedding_dim,
            self.grid_size,
            cls_token_num=self.encoder_cls_token_num,
        )
        pos_embed.data.copy_(torch.from_numpy(pos_embed_data).float().unsqueeze(0))
        return pos_embed

    def _update_patch_embed_layers(self, patch_embed):
        """Updates the patch embedding layers."""
        # Update patch projection layer
        # Use 2, as the spectrogram has 2 channels
        patch_embed.proj = torch.nn.Conv2d(
            self.in_channels,
            self.encoder_embedding_dim,
            kernel_size=(self.patch_strategy.fshape, self.patch_strategy.tshape),
            stride=(self.patch_strategy.fstride, self.patch_strategy.tstride),
        )
        patch_embed.num_patch = self.num_patches

    def pass_through_encoder(self, x, non_mask_index, B):
        """Passes the input through the Encoder Transformer network."""
        # Add positional embeddings to the x.
        x = x + self.encoder.pos_embedding[:, self.encoder_cls_token_num :, :]
        x = x[non_mask_index, :].reshape((B, -1, x.shape[-1]))
        cls_token = (
            self.cls_token.expand(B, -1, -1)
            + self.encoder.pos_embedding[:, :1, :]
        )

        try:
            dist_token = (
                self.encoder.dist_token.expand(B, -1, -1)
                + self.encoder.pos_embedding[:, 1:2, :]
            )
            x = torch.cat((cls_token, dist_token, x), dim=1)

        except Exception as e:
            x = torch.cat((cls_token, x), dim=1)

        x = self.encoder.dropout(x)
        for block in self.encoder.layers:
            x = block(x)
        return self.encoder.ln(x)

    def pass_through_decoder(self, encoder_output, non_mask_index, B):
        encoder_output = self.decoder_embed(encoder_output)
        x_ = repeat_token(
            self.mask_token, (B, self.num_patches)
        ).type_as(encoder_output)
        x_[non_mask_index, :] = encoder_output[
            :, self.encoder_cls_token_num :, :
        ].reshape((-1, encoder_output.shape[-1]))
        x_ = x_.reshape((B, -1, encoder_output.shape[-1]))

        # Concatenate the CLS and possibly the distillation tokens from the encoder.
        # We cannot do that with multi-windowed attention, though,
        # so remove the CLS token from the decoder!
        if self.use_mwmae_decoder:
            x = x_
            return_cut = 0
        else:
            x = torch.cat(
                [encoder_output[:, : self.encoder_cls_token_num, :], x_], dim=1
            )
            return_cut = self.encoder_cls_token_num
        x = x + self.decoder_pos_embed  # add the pos embeds
        # Pass through transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        pred = self.spec_pred(x)
        pred = pred[:, return_cut:, :]
        return pred

    def _get_segment_representation(self, x, strategy="mean"):
        """Extract audio representation using different strategies."""
        # Put the model in eval mode when getting representations.
        assert x.shape[1] == self.in_channels, (
            f"GRAMT expects {self.in_channels} input channels, but the feature has "
            f"shape {x.shape}, whose channel dimension is incompatible"
        )
        B = x.shape[0]
        x = x.transpose(2, 3)
        x = self.spectrogram_normalize(x)
        patches = self.patch_strategy.patch(x)
        patches = patches.flatten(2)
        encoded_patches = self.patch_strategy.embed(x, self.patch_embed)
        mask = torch.zeros((B, self.num_patches), dtype=torch.bool, device=x.device)
        x = self.pass_through_encoder(encoded_patches, ~mask, B)
        if strategy == "mean":
            return x[:, self.encoder_cls_token_num :, :].mean(axis=1)
        elif strategy == "sum":
            return x[:, self.encoder_cls_token_num :, :].sum(axis=1)
        elif strategy == "cls":
            return x[:, 0, :]
        elif strategy == "raw":
            x = x[:, self.encoder_cls_token_num :, :]
            grid_size = self.grid_size
            f, t = grid_size
            # We have 25 time patches in 2 second audio. We need to have 20 for STARSS22.
            outcome = rearrange(
                x, "b (f t) d -> b t (f d)", f=f, d=self.encoder_embedding_dim
            )
            return outcome
        else:
            raise ValueError(f"Strategy '{strategy}' is unrecognized.")

    def get_audio_representation(self, x, strategy="mean"):
        unit_frames = self.input_length
        cur_frames = x.shape[2]
        pad_frames = unit_frames - (cur_frames % unit_frames)
        if pad_frames > 0:
            # Pad with constant 0s along the time dimension:
            # (last dim left, last dim right, time before, time after)
            pad_arg = (
                0,
                0,
                0,
                pad_frames,
            )
            x = torch.nn.functional.pad(x, pad_arg, mode="constant")

        embeddings = []
        # Now get the embeddings of the model.
        for i in range(x.shape[2] // unit_frames):
            x_inp = x[:, :, i * unit_frames : (i + 1) * unit_frames, :]
            with torch.no_grad():
                embedding = self._get_segment_representation(
                    x_inp, strategy=strategy
                )
            embeddings.append(embedding)
        # Stack the embeddings here if it is raw
        if strategy == "raw":
            x = torch.hstack(embeddings)
            pad_emb_frames = int(embeddings[0].shape[1] * pad_frames / unit_frames)
            if pad_emb_frames > 0:
                x = x[:, :-pad_emb_frames]  # remove padded tail
            return x
        else:
            x = torch.stack(embeddings, dim=1)
            return x
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d5b20b8decda54204192d9d7eae5fdd70e93bb1071e3d4cdebf261fd7e7d160
size 446080184
modeling_gramt_binaural_time.py ADDED
@@ -0,0 +1,41 @@
from transformers import PreTrainedModel
from transformers import AutoConfig, AutoModel

from .model import GRAMT
from .configuration_gramt_binaural_time import GRAMTBinauralTimeConfig


class GRAMTBinauralTimeModel(PreTrainedModel):
    config_class = GRAMTBinauralTimeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = GRAMT(
            in_channels=config.in_channels,
            decoder_mlp_ratio=config.decoder_mlp_ratio,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            decoder_embedding_dim=config.decoder_embedding_dim,
            decoder_window_sizes=config.decoder_window_sizes,
            encoder_num_layers=config.encoder_num_layers,
            encoder_num_heads=config.encoder_num_heads,
            encoder_hidden_dim=config.encoder_hidden_dim,
            encoder_mlp_ratio=config.encoder_mlp_ratio,
            encoder_dropout=config.encoder_dropout,
            encoder_attention_dropout=config.encoder_attention_dropout,
            encoder_norm_layer_eps=config.encoder_norm_layer_eps,
            patch_size=config.patch_size,
            frequency_stride=config.frequency_stride,
            time_stride=config.time_stride,
            # GRAMT expects input_length; the config defines input_length rather than max_length.
            input_length=config.input_length,
            num_mel_bins=config.num_mel_bins,
        )

    def forward(self, tensor, strategy="raw"):
        return self.model.get_audio_representation(tensor, strategy=strategy)


gram = GRAMTBinauralTimeModel(GRAMTBinauralTimeConfig())
AutoConfig.register("gramt-binaural-time", GRAMTBinauralTimeConfig)
AutoModel.register(GRAMTBinauralTimeConfig, GRAMTBinauralTimeModel)
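For reference, a minimal usage sketch. The repo id below is a placeholder, and the input layout (batch, 2 channels, time frames, 128 mel bins) is inferred from GRAMT.get_audio_representation, which pads and slices dimension 2 into 200-frame segments; the log-mel feature extraction itself is not part of this upload.

import torch
from transformers import AutoModel

# Placeholder repo id; trust_remote_code is needed because the model code ships with the checkpoint.
model = AutoModel.from_pretrained("<user>/<gramt-binaural-time-repo>", trust_remote_code=True)
model.eval()

# Dummy binaural log-mel spectrogram: (batch, channels=2, time_frames, mel_bins=128)
spec = torch.randn(1, 2, 200, 128)

with torch.no_grad():
    emb = model(spec)   # forward defaults to strategy="raw"
print(emb.shape)        # per-time-patch embeddings, e.g. (1, 100, 768) for a 200-frame input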
mwmae.py ADDED
@@ -0,0 +1,434 @@
import collections.abc
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from .droppath import DropPath


def constant_init(tensor, constant=0.0):
    nn.init.constant_(tensor, constant)
    return tensor


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


class Mlp(nn.Module):
    def __init__(
        self,
        in_features=None,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = activation
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, train: bool = True):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x) if train else x
        x = self.fc2(x)
        x = self.drop(x) if train else x
        return x


class Attention(nn.Module):
    """
    Default multihead attention
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj.weight)

    def forward(self, x, train: bool = True):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn) if train else attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x) if train else x
        return x


def window_partition1d(x, window_size):
    B, W, C = x.shape
    x = x.view(B, W // window_size, window_size, C)
    windows = x.view(-1, window_size, C)
    return windows


def window_reverse1d(windows, window_size, W: int):
    B = int(windows.shape[0] / (W / window_size))
    x = windows.view(B, W // window_size, window_size, -1)
    x = x.view(B, W, -1)
    return x


def get_relative_position_index1d(win_w):
    # get pair-wise relative position index for each token inside the window
    coords = torch.stack(torch.meshgrid(torch.arange(win_w)))

    relative_coords = coords[:, :, None] - coords[:, None, :]  # 1, Ww, Ww
    relative_coords = relative_coords.permute(1, 2, 0)  # Ww, Ww, 1

    relative_coords[:, :, 0] += win_w - 1  # shift to start from 0

    return relative_coords.sum(-1)  # Ww*Ww


class WindowedAttentionHead(nn.Module):
    def __init__(self, head_dim, window_size, shift_windows=False, attn_drop=0.0):
        super().__init__()
        self.head_dim = head_dim
        self.window_size = window_size
        self.shift_windows = shift_windows
        self.attn_drop = attn_drop

        self.scale = self.head_dim**-0.5
        self.window_area = self.window_size * 1

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1, 1))
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Get relative position index
        self.register_buffer(
            "relative_position_index", get_relative_position_index1d(window_size)
        )

        self.drop_layer = nn.Dropout(attn_drop) if attn_drop > 0 else None

        if shift_windows:
            self.shift_size = window_size // 2
        else:
            self.shift_size = 0
        assert 0 <= self.shift_size < self.window_size, (
            "shift_size must be in [0, window_size)"
        )

    def forward(self, q, k, v, train: bool = True):
        B, W, C = q.shape

        mask = None
        if self.shift_size > 0:
            img_mask = torch.zeros((1, W, 1), device=q.device)
            cnt = 0
            for w in (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            ):
                img_mask[:, w, :] = cnt
                cnt += 1
            mask_windows = window_partition1d(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size)
            mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            mask = mask.masked_fill(mask != 0, -100.0).masked_fill(mask == 0, 0.0)

            q = torch.roll(q, shifts=-self.shift_size, dims=1)
            k = torch.roll(k, shifts=-self.shift_size, dims=1)
            v = torch.roll(v, shifts=-self.shift_size, dims=1)

        q = window_partition1d(q, self.window_size)
        k = window_partition1d(k, self.window_size)
        v = window_partition1d(v, self.window_size)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = attn + self._get_rel_pos_bias()

        if mask is not None:
            B_, N, _ = attn.shape
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, N, N) + mask.unsqueeze(0)
            attn = attn.view(-1, N, N)
            attn = attn.softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        if self.drop_layer is not None and train:
            attn = self.drop_layer(attn)

        x = attn @ v

        # merge windows
        shifted_x = window_reverse1d(x, self.window_size, W=W)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1)
        else:
            x = shifted_x

        return x, attn

    def _get_rel_pos_bias(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.window_area, self.window_area, -1)  # Ww,Ww,1
        relative_position_bias = relative_position_bias.permute(2, 0, 1)  # 1, Ww, Ww
        return relative_position_bias


class AttentionHead(nn.Module):
    def __init__(self, head_dim, attn_drop=0.0):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.drop_layer = nn.Dropout(attn_drop) if attn_drop > 0 else None

    def forward(self, q, k, v, train: bool = True):
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        if self.drop_layer is not None and train:
            attn = self.drop_layer(attn)

        x = attn @ v
        return x, attn


class WindowedMultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim,
        window_sizes,
        shift_windows=False,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        nn.init.xavier_uniform_(self.qkv.weight)

        if isinstance(window_sizes, int):
            window_sizes = _ntuple(num_heads)(window_sizes)
        else:
            assert len(window_sizes) == num_heads

        self.attn_heads = nn.ModuleList()
        for i in range(num_heads):
            ws_i = window_sizes[i]
            if ws_i == 0:
                self.attn_heads.append(AttentionHead(self.head_dim, attn_drop))
            else:
                self.attn_heads.append(
                    WindowedAttentionHead(
                        self.head_dim,
                        window_size=ws_i,
                        shift_windows=shift_windows,
                        attn_drop=attn_drop,
                    )
                )

        self.proj = nn.Linear(dim, dim)
        nn.init.xavier_uniform_(self.proj.weight)
        self.drop_layer = nn.Dropout(proj_drop) if proj_drop > 0 else None

    def forward(self, x, train: bool = True):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 3, 0, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        o = []
        for i in range(self.num_heads):
            head_i, attn_i = self.attn_heads[i](q[i], k[i], v[i], train=train)
            o.append(head_i.unsqueeze(0))

        o = torch.cat(o, dim=0)
        o = o.permute(1, 2, 0, 3).reshape(B, N, -1)
        o = self.proj(o)

        if self.drop_layer is not None and train:
            o = self.drop_layer(o)

        return o


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.gamma


class BNWrapper(nn.Module):
    def __init__(
        self, num_features, use_running_average=True, use_bias=True, use_scale=True
    ):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features, affine=use_scale or use_bias)

    def forward(self, x, train=True):
        # BatchNorm1d takes no training flag; train/eval behaviour follows the module state.
        return self.bn(x)


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=F.gelu,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        outputs1 = self.attn(self.norm1(x), train=train)

        if self.init_values is not None:
            outputs1 = self.layer_scale1(outputs1)

        x = x + self.drop_path(outputs1) if train else x + outputs1

        outputs2 = self.mlp(self.norm2(x), train=train)

        if self.init_values is not None:
            outputs2 = self.layer_scale2(outputs2)

        x = x + self.drop_path(outputs2) if train else x + outputs2
        return x


class MWMHABlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_sizes,
        shift_windows=False,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=F.gelu,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.wmha = WindowedMultiHeadAttention(
            dim,
            window_sizes=window_sizes,
            shift_windows=shift_windows,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        outputs1 = self.wmha(self.norm1(x), train=train)

        if self.init_values is not None:
            outputs1 = self.layer_scale1(outputs1)

        x = x + self.drop_path(outputs1) if train else x + outputs1

        outputs2 = self.mlp(self.norm2(x), train=train)

        if self.init_values is not None:
            outputs2 = self.layer_scale2(outputs2)

        x = x + self.drop_path(outputs2) if train else x + outputs2
        return x
patching_utils.py ADDED
@@ -0,0 +1,126 @@
import torch
from torch import nn


def generate_patches(input, fstride, tstride, fshape, tshape):
    r"""Function that extracts patches from tensors and stacks them.

    See :class:`~kornia.contrib.ExtractTensorPatches` for details.

    Args:
        input: tensor image where to extract the patches with shape :math:`(B, C, H, W)`.

    Returns:
        the tensor with the extracted patches with shape :math:`(B, N, C, H_{out}, W_{out})`.

    Examples:
        >>> input = torch.arange(9.).view(1, 1, 3, 3)
        >>> patches = extract_tensor_patches(input, (2, 3))
        >>> input
        tensor([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]])
        >>> patches[:, -1]
        tensor([[[[3., 4., 5.],
                  [6., 7., 8.]]]])

    """
    batch_size, num_channels = input.size()[:2]
    dims = range(2, input.dim())
    for dim, patch_size, stride in zip(dims, (fshape, tshape), (fstride, tstride)):
        input = input.unfold(dim, patch_size, stride)
    input = input.permute(0, *dims, 1, *(dim + len(dims) for dim in dims)).contiguous()
    return input.view(batch_size, -1, num_channels, fshape, tshape)


def combine_patches(
    patches,
    original_size,
    fstride,
    tstride,
    fshape,
    tshape,
    eps: float = 1e-8,
):
    r"""Restore input from patches.

    See :class:`~kornia.contrib.CombineTensorPatches` for details.

    Args:
        patches: patched tensor with shape :math:`(B, N, C, H_{out}, W_{out})`.

    Return:
        The combined patches in an image tensor with shape :math:`(B, C, H, W)`.

    Example:
        >>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
        >>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])

    .. note::
        This function is supposed to be used in conjunction with :func:`extract_tensor_patches`.

    """
    if patches.ndim != 5:
        raise ValueError(
            f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}"
        )
    ones = torch.ones(
        patches.shape[0],
        patches.shape[2],
        original_size[0],
        original_size[1],
        device=patches.device,
        dtype=patches.dtype,
    )
    restored_size = ones.shape[2:]

    patches = patches.permute(0, 2, 3, 4, 1)
    patches = patches.reshape(patches.shape[0], -1, patches.shape[-1])
    int_flag = 0
    if not torch.is_floating_point(patches):
        int_flag = 1
        dtype = patches.dtype
        patches = patches.float()
        ones = ones.float()

    # Calculate normalization map
    unfold_ones = torch.nn.functional.unfold(
        ones, kernel_size=(fshape, tshape), stride=(fstride, tstride)
    )
    norm_map = torch.nn.functional.fold(
        input=unfold_ones,
        output_size=restored_size,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    # Restored tensor
    saturated_restored_tensor = torch.nn.functional.fold(
        input=patches,
        output_size=restored_size,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    # Remove the saturation effect due to multiple summations
    restored_tensor = saturated_restored_tensor / (norm_map + eps)
    if int_flag:
        restored_tensor = restored_tensor.to(dtype)
    return restored_tensor


# get the shape of the intermediate representation.
def get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape):
    test_input = torch.randn(1, 2, input_fdim, input_tdim)
    test_proj = nn.Conv2d(
        2,
        2,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    return f_dim, t_dim
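A small round-trip sketch for the two helpers above (illustrative, assuming the repository directory is on sys.path): with non-overlapping windows the fold-based normalisation map is all ones, so combine_patches inverts generate_patches up to the eps term.

import torch
from patching_utils import combine_patches, generate_patches

x = torch.randn(1, 2, 128, 200)             # (B, C, freq, time) spectrogram-like tensor
patches = generate_patches(x, fstride=128, tstride=2, fshape=128, tshape=2)
print(patches.shape)                         # torch.Size([1, 100, 2, 128, 2])

x_rec = combine_patches(patches, original_size=(128, 200),
                        fstride=128, tstride=2, fshape=128, tshape=2)
print(torch.allclose(x, x_rec, atol=1e-4))   # True: reconstruction up to the eps normalisation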
pos_embed.py ADDED
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


# https://github.com/facebookresearch/AudioMAE/blob/main/util/pos_embed.py
import numpy as np
import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token_num):
    """
    grid_size: int or (height, width) tuple of the grid
    return:
    pos_embed: [grid_h*grid_w, embed_dim] or [cls_token_num+grid_h*grid_w, embed_dim] (w/ or w/o cls tokens)
    """
    if isinstance(grid_size, int):
        gH = grid_size
        gW = grid_size
    else:
        gH = grid_size[0]
        gW = grid_size[1]
    grid_h = np.arange(gH, dtype=np.float64)
    grid_w = np.arange(gW, dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, gH, gW])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    for _ in range(cls_token_num):
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) tuple of the grid
    return:
    pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float64)
    grid_w = np.arange(grid_size[1], dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        # new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size[0], new_size[1]),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
            pos_tokens = pos_embed_checkpoint[:, 1:, :]  # remove
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            )  # .permute(0, 3, 1, 2)
            # pos_tokens = torch.nn.functional.interpolate(
            #     pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)

            # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            pos_tokens = pos_tokens[:, :, : new_size[1], :]  # assume only time diff
            pos_tokens = pos_tokens.flatten(1, 2)
            new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_patch_embed_audio(
    model,
    checkpoint_model,
    orig_channel,
    new_channel=1,
    kernel_size=(16, 16),
    stride=(16, 16),
    padding=(0, 0),
):
    if orig_channel != new_channel:
        if "patch_embed.proj.weight" in checkpoint_model:
            # aggregate 3 channels in rgb ckpt to 1 channel for audio
            new_proj_weight = torch.nn.Parameter(
                torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
                    1
                )
            )
            checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
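A quick shape check for the sincos helper as it is used in model.py (grid_size=(p_f_dim, p_t_dim), one CLS slot for the encoder); purely illustrative, assuming the repository directory is on sys.path.

import numpy as np
from pos_embed import get_2d_sincos_pos_embed

# Default GRAMT grid for (128, 2) patches over a 128 x 200 spectrogram: 1 x 100 patches.
pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=(1, 100), cls_token_num=1)
print(pe.shape)                  # (101, 768): 100 patch positions plus one all-zero CLS row
print(np.allclose(pe[0], 0.0))   # True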
swin.py ADDED
@@ -0,0 +1,522 @@
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# Modified by Zhenda Xie
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, (
            "shift_size must be in [0, window_size)"
        )

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(
                img_mask, self.window_size
            )  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
344
+ if self.shift_size > 0:
345
+ x = torch.roll(
346
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
347
+ )
348
+ else:
349
+ x = shifted_x
350
+ x = x.view(B, H * W, C)
351
+
352
+ # FFN
353
+ x = shortcut + self.drop_path(x)
354
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
355
+
356
+ return x
357
+
358
+ def extra_repr(self) -> str:
359
+ return (
360
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
361
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
362
+ )
363
+
364
+ def flops(self):
365
+ flops = 0
366
+ H, W = self.input_resolution
367
+ # norm1
368
+ flops += self.dim * H * W
369
+ # W-MSA/SW-MSA
370
+ nW = H * W / self.window_size / self.window_size
371
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
372
+ # mlp
373
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
374
+ # norm2
375
+ flops += self.dim * H * W
376
+ return flops
377
+
378
+
379
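A minimal usage sketch for SwinTransformerBlock (editorial illustration, not part of the uploaded file; the dimensions are arbitrary). The block is shape-preserving: tokens go in and come out as (B, H*W, C).

# One shifted-window block on an 8x8 token grid.
import torch
blk = SwinTransformerBlock(dim=96, input_resolution=(8, 8), num_heads=3,
                           window_size=4, shift_size=2)
x = torch.randn(2, 8 * 8, 96)  # (B, H*W, C)
y = blk(x)                     # (2, 64, 96)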
+ class PatchMerging(nn.Module):
380
+ r"""Patch Merging Layer.
381
+
382
+ Args:
383
+ input_resolution (tuple[int]): Resolution of input feature.
384
+ dim (int): Number of input channels.
385
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
386
+ """
387
+
388
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
389
+ super().__init__()
390
+ self.input_resolution = input_resolution
391
+ self.dim = dim
392
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
393
+ self.norm = norm_layer(4 * dim)
394
+
395
+ def forward(self, x):
396
+ """
397
+ x: B, H*W, C
398
+ """
399
+ H, W = self.input_resolution
400
+ B, L, C = x.shape
401
+ assert L == H * W, "input feature has wrong size"
402
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
403
+
404
+ x = x.view(B, H, W, C)
405
+
406
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
407
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
408
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
409
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
410
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
411
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
412
+
413
+ x = self.norm(x)
414
+ x = self.reduction(x)
415
+
416
+ return x
417
+
418
+ def extra_repr(self) -> str:
419
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
420
+
421
+ def flops(self):
422
+ H, W = self.input_resolution
423
+ flops = H * W * self.dim
424
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
425
+ return flops
426
+
427
+
428
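A minimal usage sketch for PatchMerging (editorial illustration, not part of the uploaded file): each spatial side is halved and the channel width is doubled, so the token count drops by a factor of four.

import torch
merge = PatchMerging(input_resolution=(8, 8), dim=96)
x = torch.randn(2, 8 * 8, 96)  # (B, H*W, C)
y = merge(x)                   # (2, 16, 192)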
+ class BasicLayer(nn.Module):
429
+ """A basic Swin Transformer layer for one stage.
430
+
431
+ Args:
432
+ dim (int): Number of input channels.
433
+ input_resolution (tuple[int]): Input resolution.
434
+ depth (int): Number of blocks.
435
+ num_heads (int): Number of attention heads.
436
+ window_size (int): Local window size.
437
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
438
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
439
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
440
+ drop (float, optional): Dropout rate. Default: 0.0
441
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
442
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
443
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
444
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
445
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
446
+ """
447
+
448
+ def __init__(
449
+ self,
450
+ dim,
451
+ input_resolution,
452
+ depth,
453
+ num_heads,
454
+ window_size,
455
+ mlp_ratio=4.0,
456
+ qkv_bias=True,
457
+ qk_scale=None,
458
+ drop=0.0,
459
+ attn_drop=0.0,
460
+ drop_path=0.0,
461
+ norm_layer=nn.LayerNorm,
462
+ downsample=None,
463
+ use_checkpoint=False,
464
+ ):
465
+ super().__init__()
466
+ self.dim = dim
467
+ self.input_resolution = input_resolution
468
+ self.depth = depth
469
+ self.use_checkpoint = use_checkpoint
470
+
471
+ # build blocks
472
+ self.blocks = nn.ModuleList(
473
+ [
474
+ SwinTransformerBlock(
475
+ dim=dim,
476
+ input_resolution=input_resolution,
477
+ num_heads=num_heads,
478
+ window_size=window_size,
479
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
480
+ mlp_ratio=mlp_ratio,
481
+ qkv_bias=qkv_bias,
482
+ qk_scale=qk_scale,
483
+ drop=drop,
484
+ attn_drop=attn_drop,
485
+ drop_path=drop_path[i]
486
+ if isinstance(drop_path, list)
487
+ else drop_path,
488
+ norm_layer=norm_layer,
489
+ )
490
+ for i in range(depth)
491
+ ]
492
+ )
493
+
494
+ # patch merging layer
495
+ if downsample is not None:
496
+ self.downsample = downsample(
497
+ input_resolution, dim=dim, norm_layer=norm_layer
498
+ )
499
+ else:
500
+ self.downsample = None
501
+
502
+ def forward(self, x):
503
+ print("IN", x.shape)
504
+ for blk in self.blocks:
505
+ if self.use_checkpoint:
506
+ x = checkpoint.checkpoint(blk, x)
507
+ else:
508
+ x = blk(x)
509
+ if self.downsample is not None:
510
+ x = self.downsample(x)
511
+ return x
512
+
513
+ def extra_repr(self) -> str:
514
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
515
+
516
+ def flops(self):
517
+ flops = 0
518
+ for blk in self.blocks:
519
+ flops += blk.flops()
520
+ if self.downsample is not None:
521
+ flops += self.downsample.flops()
522
+ return flops
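A minimal usage sketch for BasicLayer (editorial illustration, not part of the uploaded file; sizes are arbitrary): two alternating W-MSA/SW-MSA blocks followed by a PatchMerging downsample.

import torch
stage = BasicLayer(dim=96, input_resolution=(8, 8), depth=2, num_heads=3,
                   window_size=4, downsample=PatchMerging)
x = torch.randn(2, 8 * 8, 96)  # (B, H*W, C)
y = stage(x)                   # (2, 16, 192) after the merging step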
utils.py ADDED
@@ -0,0 +1,249 @@
1
+ import collections.abc
2
+ import math
3
+ import sys
4
+ from itertools import repeat
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import timm
9
+ import torch
10
+ from torch import nn
11
+ from torchvision.models.vision_transformer import Encoder
12
+
13
+
14
+ from typing import Tuple
15
+ from functools import partial
16
+ from collections.abc import Iterable # import directly from collections for Python < 3.3
17
+
18
+
19
+ def plot_fbank(fbank, title=None, save_path=None, **kwargs):
20
+ fig, axs = plt.subplots(min(4, fbank.shape[0]), 1, sharex=True, sharey=True)
21
+ if not isinstance(axs, Iterable):
22
+ axs = np.array([axs])
23
+ vmin, vmax = kwargs.get("vmin", None), kwargs.get("vmax", None)
24
+ # max 4 channels...
25
+ for channel in range(0, min(4, fbank.shape[0])):
26
+ axs[channel].set_title(f"Filter bank channel {channel}, {title}")
27
+ im = axs[channel].imshow(fbank[channel].T, aspect="auto", vmin=vmin, vmax=vmax)
28
+ axs[channel].set_ylabel("mel")
29
+ axs[channel].set_xlabel("time")
30
+ plt.gca().invert_yaxis()
31
+ plt.tight_layout()
32
+ fig.colorbar(im, ax=axs.ravel().tolist())
33
+ plt.show()
34
+ if save_path:
35
+ fig.savefig(save_path)
36
+ plt.close()
37
+ return fig
38
+
39
+
40
+ # From PyTorch Internals to create the tuples of the given iterable.
41
+ def _ntuple(n):
42
+ def parse(x):
43
+ # if x is already an instance of iterable object, create a tuple out of it
44
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
45
+ return tuple(x)
46
+ # Otherwise repeat the x, n times, and create a tuple.
47
+ return tuple(repeat(x, n))
48
+
49
+ return parse
50
+
51
+
52
+ class PatchEmbed(nn.Module):
53
+ """Image to Patch Embedding"""
54
+
55
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
56
+ super().__init__()
57
+ img_size = _ntuple(2)(img_size)
58
+ patch_size = _ntuple(2)(patch_size)
59
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
60
+ self.img_size = img_size
61
+ self.patch_size = patch_size
62
+ self.num_patches = num_patches
63
+
64
+ self.proj = nn.Conv2d(
65
+ in_channels=in_chans,
66
+ out_channels=embed_dim,
67
+ kernel_size=patch_size,
68
+ stride=patch_size,
69
+ )
70
+
71
+ # We need to override these.
72
+ def forward(self, x):
73
+ x = self.proj(x).flatten(2).transpose(1, 2)
74
+ return x
75
+
76
+
77
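A minimal usage sketch for PatchEmbed (editorial illustration, not part of the uploaded file): a 224x224 RGB image becomes 14*14 = 196 patch tokens of width 768.

import torch
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 196, 768])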
+ def get_sinusoid_encoding(n_position, d_hid):
78
+ """Sinusoid position encoding table"""
79
+
80
+ def get_position_angle_vec(position):
81
+ return [
82
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
83
+ for hid_j in range(d_hid)
84
+ ]
85
+
86
+ sinusoid_table = np.array(
87
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
88
+ )
89
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
90
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
91
+
92
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
93
+
94
+
95
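A minimal usage sketch for get_sinusoid_encoding (editorial illustration, not part of the uploaded file); the table is deterministic, so it can be assigned to a non-trainable positional-embedding buffer.

pe = get_sinusoid_encoding(n_position=100, d_hid=768)
print(pe.shape)  # torch.Size([1, 100, 768])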
+ def create_pretrained_model(model_size,
96
+ encoder_num_layers = 12,
97
+ encoder_num_heads = 12,
98
+ encoder_hidden_dim = 768,
99
+ encoder_mlp_dim= 3072,
100
+ encoder_dropout = 0.0,
101
+ encoder_attention_dropout = 0.0,
102
+ encoder_norm_layer_eps = 1e-6):
103
+ if model_size == "tiny":
104
+ v = timm.create_model("deit_tiny_distilled_patch16_224", pretrained=False)
105
+ hidden_dim = 192  # DeiT-tiny (deit_tiny_distilled_patch16_224) uses a 192-dim embedding
106
+
107
+ elif model_size == "small":
108
+ v = timm.create_model("deit_small_distilled_patch16_224", pretrained=False)
109
+ hidden_dim = 384
110
+
111
+ elif model_size == "base":
112
+ v = Encoder(
113
+ seq_length = 0, #Only used for pos_embeddings and we set them later!
114
+ num_layers = encoder_num_layers,
115
+ num_heads = encoder_num_heads,
116
+ hidden_dim = encoder_hidden_dim,
117
+ mlp_dim= encoder_mlp_dim,
118
+ dropout = encoder_dropout,
119
+ attention_dropout = encoder_attention_dropout,
120
+ norm_layer = partial(nn.LayerNorm, eps=encoder_norm_layer_eps))
121
+ hidden_dim = encoder_hidden_dim
122
+
123
+ elif model_size == "base_nokd":
124
+ v = timm.create_model("deit_base_patch16_384", pretrained=False)
125
+ hidden_dim = 768
126
+
127
+ else:
128
+ print("Wrong model size!")
129
+ sys.exit(1)
130
+
131
+ return v, hidden_dim
132
+
133
+
134
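A minimal usage sketch for create_pretrained_model (editorial illustration, not part of the uploaded file): the "base" configuration builds a torchvision Encoder with the defaults above and returns its hidden width.

backbone, hidden_dim = create_pretrained_model("base")
print(type(backbone).__name__, hidden_dim)  # Encoder 768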
+ def _trunc_normal_(tensor, mean, std, a, b):
135
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
136
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
137
+ def norm_cdf(x):
138
+ # Computes standard normal cumulative distribution function
139
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
140
+
141
+ # Values are generated by using a truncated uniform distribution and
142
+ # then using the inverse CDF for the normal distribution.
143
+ # Get upper and lower cdf values
144
+ left = norm_cdf((a - mean) / std)
145
+ up = norm_cdf((b - mean) / std)
146
+
147
+ # Uniformly fill tensor with values from [l, u], then translate to
148
+ # [2l-1, 2u-1].
149
+ tensor.uniform_(2 * left - 1, 2 * up - 1)
150
+
151
+ # Use inverse cdf transform for normal distribution to get truncated
152
+ # standard normal
153
+ tensor.erfinv_()
154
+
155
+ # Transform to proper mean, std
156
+ tensor.mul_(std * math.sqrt(2.0))
157
+ tensor.add_(mean)
158
+
159
+ # Clamp to ensure it's in the proper range
160
+ tensor.clamp_(min=a, max=b)
161
+ return tensor
162
+
163
+
164
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
165
+ # type: (Tensor, float, float, float, float) -> Tensor
166
+ r"""Fills the input Tensor with values drawn from a truncated
167
+ normal distribution. The values are effectively drawn from the
168
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
169
+ with values outside :math:`[a, b]` redrawn until they are within
170
+ the bounds. The method used for generating the random values works
171
+ best when :math:`a \leq \text{mean} \leq b`.
172
+
173
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
174
+ applied while sampling the normal with mean/std applied, therefore a, b args
175
+ should be adjusted to match the range of mean, std args.
176
+
177
+ Args:
178
+ tensor: an n-dimensional `torch.Tensor`
179
+ mean: the mean of the normal distribution
180
+ std: the standard deviation of the normal distribution
181
+ a: the minimum cutoff value
182
+ b: the maximum cutoff value
183
+ Examples:
184
+ >>> w = torch.empty(3, 5)
185
+ >>> nn.init.trunc_normal_(w)
186
+ """
187
+ with torch.no_grad():
188
+ return _trunc_normal_(tensor, mean, std, a, b)
189
+
190
+
191
+ def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
192
+ """Expands the index along the last dimension of the input tokens.
193
+
194
+ Args:
195
+ index:
196
+ Index tensor with shape (batch_size, idx_length) where each entry is
197
+ an index in [0, sequence_length).
198
+ tokens:
199
+ Tokens tensor with shape (batch_size, sequence_length, dim).
200
+
201
+ Returns:
202
+ Index tensor with shape (batch_size, idx_length, dim) where the original
203
+ indices are repeated dim times along the last dimension.
204
+
205
+ """
206
+ dim = tokens.shape[-1]
207
+ index = index.unsqueeze(-1).expand(-1, -1, dim)
208
+ return index
209
+
210
+ def set_at_index(
211
+ tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
212
+ ) -> torch.Tensor:
213
+ """Copies all values into the input tensor at the given indices.
214
+
215
+ Args:
216
+ tokens:
217
+ Tokens tensor with shape (batch_size, sequence_length, dim).
218
+ index:
219
+ Index tensor with shape (batch_size, index_length).
220
+ value:
221
+ Value tensor with shape (batch_size, index_length, dim).
222
+
223
+ Returns:
224
+ Tokens tensor with shape (batch_size, sequence_length, dim) containing
225
+ the new values.
226
+
227
+ """
228
+ index = expand_index_like(index, tokens)
229
+ return torch.scatter(tokens, 1, index, value)
230
+
231
+
232
+
233
+
234
+ def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
235
+ """Repeats a token size times.
236
+
237
+ Args:
238
+ token:
239
+ Token tensor with shape (1, 1, dim).
240
+ size:
241
+ (batch_size, sequence_length) tuple.
242
+
243
+ Returns:
244
+ Tensor with shape (batch_size, sequence_length, dim) containing copies
245
+ of the input token.
246
+
247
+ """
248
+ batch_size, sequence_length = size
249
+ return token.repeat(batch_size, sequence_length, 1)
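A minimal sketch combining repeat_token, expand_index_like and set_at_index for MAE-style masking (editorial illustration, not part of the uploaded file; the indices and sizes are arbitrary): selected positions are overwritten with copies of a single mask token.

import torch
B, S, D = 2, 8, 16
tokens = torch.randn(B, S, D)
mask_token = torch.zeros(1, 1, D)
idx_mask = torch.tensor([[1, 4, 6], [0, 2, 7]])            # (B, idx_length)
filler = repeat_token(mask_token, (B, idx_mask.shape[1]))  # (B, 3, D)
tokens = set_at_index(tokens, idx_mask, filler)            # masked rows now zero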