ASLP-lab committed on
Commit
d0690fd
·
1 Parent(s): 512c889

add one-click func

config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "architectures": [
3
+ "SongFormerModel"
4
+ ],
5
+ "model_type": "songformer",
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_songformer.SongFormerConfig",
8
+ "AutoModel": "modeling_songformer.SongFormerModel"
9
+ }
10
+ }
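Note: the `auto_map` above is what enables the one-click loading added in this commit. Below is a minimal usage sketch; the repo id and audio path are placeholders rather than values from this commit, and `trust_remote_code=True` is required because the modeling code ships inside the repo.

```
# Hypothetical one-click usage; repo id and audio path are placeholders.
from transformers import AutoModel

model = AutoModel.from_pretrained("ASLP-lab/SongFormer", trust_remote_code=True)
segments = model("/path/to/song.wav")  # forward() accepts a file path or a 24 kHz waveform
# -> [{"label": "intro", "start": 0.0, "end": 12.5}, ...]
```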
configuration_songformer.py ADDED
@@ -0,0 +1,24 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class SongFormerConfig(PretrainedConfig):
4
+ """Configuration class to store the configuration of a custom model."""
5
+
6
+ model_type = "songformer"  # keep consistent with "model_type" in config.json
7
+
8
+ def __init__(
9
+ self,
10
+ win_size=420,
11
+ hop_size=420,
12
+ num_classes=128,
13
+ no_rule_post_processing=False,
14
+ local_maxima_filter_size=3,
15
+ frame_rates=8.333,
16
+ **kwargs
17
+ ):
18
+ super().__init__(**kwargs)
19
+ self.win_size = win_size
20
+ self.hop_size = hop_size
21
+ self.num_classes = num_classes
22
+ self.no_rule_post_processing = no_rule_post_processing
23
+ self.local_maxima_filter_size = local_maxima_filter_size
24
+ self.frame_rates = frame_rates
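Note: a quick sketch of how this class round-trips to the `config.json` shipped above (the export directory is hypothetical).

```
# Instantiate the config with the shipped defaults and serialize it.
from configuration_songformer import SongFormerConfig

cfg = SongFormerConfig(win_size=420, hop_size=420, num_classes=128)
cfg.save_pretrained("./export")  # writes export/config.json
print(cfg.frame_rates)           # 8.333 frames per second after downsampling
```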
dataset/custom_types.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ MsaInfo
3
+ A list of (timestamp, label) tuples used to represent music structure
4
+ analysis (MSA). The first element of the tuple is a float timestamp
5
+ (in seconds) and the second is a string label.
6
+
7
+ Example
8
+ -------
9
+ >>> msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus")]
10
+ """
11
+
12
+ from typing import List, Tuple
13
+
14
+ MsaInfo = List[Tuple[float, str]]
dataset/label2id.py ADDED
@@ -0,0 +1,163 @@
1
+ LABEL_TO_ID = {
2
+ "intro": 0,
3
+ "verse": 1,
4
+ "chorus": 2,
5
+ "bridge": 3,
6
+ "inst": 4,
7
+ "outro": 5,
8
+ "silence": 6,
9
+ "intchorus": 7,
10
+ "prechorus": 8,
11
+ "gtrbreak": 9,
12
+ "solo": 10,
13
+ "quietchorus": 11,
14
+ "bre": 12,
15
+ "break": 13,
16
+ "introverse": 14,
17
+ "mainriff": 15,
18
+ "chorushalf": 16,
19
+ "instintro": 17,
20
+ "gtr": 18,
21
+ "vocaloutro": 19,
22
+ "verse_slow": 20,
23
+ "fadein": 21,
24
+ "saxobeat": 22,
25
+ "transition": 23,
26
+ "verse1a": 24,
27
+ "build": 25,
28
+ "pre-chorus": 26,
29
+ "outroa": 27,
30
+ "bigoutro": 28,
31
+ "fast": 29,
32
+ "instrumentalverse": 30,
33
+ "section": 31,
34
+ "choruspart": 32,
35
+ "instbridge": 33,
36
+ "guitar": 34,
37
+ "instrumental": 35,
38
+ "breakdown": 36,
39
+ "rhythmlessintro": 37,
40
+ "intropt": 38,
41
+ "interlude": 39,
42
+ "postchorus": 40,
43
+ "postverse": 41,
44
+ "opening": 42,
45
+ "altchorus": 43,
46
+ "stutter": 44,
47
+ "oddriff": 45,
48
+ "synth": 46,
49
+ "preverse": 47,
50
+ "quiet": 48,
51
+ "raps": 49,
52
+ "verseinst": 50,
53
+ "instchorus": 51,
54
+ "chorus_instrumental": 52,
55
+ "slowverse": 53,
56
+ "slow": 54,
57
+ "worstthingever": 55,
58
+ "transition2a": 56,
59
+ "miniverse": 57,
60
+ "refrain": 58,
61
+ "introchorus": 59,
62
+ "drumroll": 60,
63
+ "guitarsolo": 61,
64
+ "versepart": 62,
65
+ "chorusinst": 63,
66
+ "ending": 64,
67
+ "no-vocal-intro": 65,
68
+ "no-vocal-interlude": 66,
69
+ "no-vocal-outro": 67,
70
+ "NO_LABEL": 68, # Only referring to cases without labels, this portion of labels will be ignored during the loss calculation process.
71
+ }
72
+
73
+ ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
74
+
75
+ # Reserve 64 embedding positions for dataset identifiers in the model.
76
+ DATASET_LABEL_TO_DATASET_ID = {
77
+ "SongForm-HX-7Class": 0, # Categories after rule mapping for HarmonixSet
78
+ "SongForm-HX-Widen": 1, # Original HarmonixSet
79
+ "SongForm-Private-Raw": 2,
80
+ "SongForm-Private": 3,
81
+ "SongForm-HX-Gemini-Relabeled": 4, # Rule-mapped HarmonixSet corrected by Gemini
82
+ "SongForm-HX-8Class": 5, # Rule-mapped (pre-chorus retained)
83
+ "SongForm-Hook": 6,
84
+ "SongForm-Gem": 7,
85
+ "SongForm-Gem-Only-Label": 8, # Use only segments with labels in SongForm-Gem
86
+ }
87
+
88
+ DATASET_ID_TO_DATASET_LABEL = {v: k for k, v in DATASET_LABEL_TO_DATASET_ID.items()}
89
+
90
+ DATASET_ID_ALLOWED_LABEL_IDS = {
91
+ 0: [0, 1, 2, 3, 4, 5, 6],
92
+ 1: [
93
+ 0,
94
+ 1,
95
+ 2,
96
+ 3,
97
+ 4,
98
+ 5,
99
+ 6,
100
+ 7,
101
+ 8,
102
+ 9,
103
+ 10,
104
+ 11,
105
+ 12,
106
+ 13,
107
+ 14,
108
+ 15,
109
+ 16,
110
+ 17,
111
+ 18,
112
+ 19,
113
+ 20,
114
+ 21,
115
+ 22,
116
+ 23,
117
+ 24,
118
+ 25,
119
+ 27,
120
+ 28,
121
+ 29,
122
+ 30,
123
+ 31,
124
+ 32,
125
+ 33,
126
+ 34,
127
+ 35,
128
+ 36,
129
+ 37,
130
+ 38,
131
+ 40,
132
+ 41,
133
+ 42,
134
+ 43,
135
+ 44,
136
+ 45,
137
+ 46,
138
+ 47,
139
+ 48,
140
+ 49,
141
+ 50,
142
+ 51,
143
+ 52,
144
+ 53,
145
+ 54,
146
+ 55,
147
+ 56,
148
+ 57,
149
+ 58,
150
+ 59,
151
+ 60,
152
+ 61,
153
+ 62,
154
+ 63,
155
+ ],
156
+ 2: [0, 1, 2, 3, 26, 39, 64, 65, 66, 67],
157
+ 3: [0, 1, 2, 3, 4, 5, 6, 26, 39, 64, 65, 66, 67],
158
+ 4: [0, 1, 2, 3, 4, 5, 6, 26],
159
+ 5: [0, 1, 2, 3, 4, 5, 6, 26],
160
+ 6: [0, 1, 2, 3, 4, 5, 6, 26],
161
+ 7: [0, 1, 2, 3, 4, 5, 6, 26],
162
+ 8: [0, 1, 2, 3, 4, 5, 6, 26],
163
+ }
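Note: these tables are combined downstream into per-dataset label masks; the sketch below mirrors the mask construction in `modeling_songformer.py` and assumes the default `num_classes=128` from `model_config.py`.

```
# Build a boolean "disallowed label" mask for one dataset id (True = masked out).
import numpy as np
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID

num_classes = 128
dataset_id = DATASET_LABEL_TO_DATASET_ID["SongForm-HX-8Class"]  # -> 5
mask = np.ones(num_classes, dtype=bool)
mask[DATASET_ID_ALLOWED_LABEL_IDS[dataset_id]] = False  # allowed ids stay unmasked
```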
model.py ADDED
@@ -0,0 +1,527 @@
1
+ import pdb
2
+ import scipy
3
+ import numpy as np
4
+
5
+ scipy.inf = np.inf
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from dataset.custom_types import MsaInfo
12
+ from msaf.eval import compute_results
13
+ from postprocessing.functional import postprocess_functional_structure
14
+ from x_transformers import Encoder
15
+ import bisect
16
+
17
+
18
+ class Head(nn.Module):
19
+ def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
20
+ super().__init__()
21
+ hidden_dims = hidden_dims or []
22
+ act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
23
+ act_layer = act_layers.get(activation.lower())
24
+ if not act_layer:
25
+ raise ValueError(f"Unsupported activation: {activation}")
26
+
27
+ dims = [input_dim] + hidden_dims + [output_dim]
28
+ layers = []
29
+ for i in range(len(dims) - 1):
30
+ layers.append(nn.Linear(dims[i], dims[i + 1]))
31
+ if i < len(dims) - 2:
32
+ layers.append(act_layer())
33
+ self.net = nn.Sequential(*layers)
34
+
35
+ def reset_parameters(self, confidence):
36
+ bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
37
+ self.net[-1].bias.data.fill_(bias_value.item())
38
+
39
+ def forward(self, x):
40
+ batch, T, C = x.shape
41
+ x = x.reshape(-1, C)
42
+ x = self.net(x)
43
+ return x.reshape(batch, T, -1)
44
+
45
+
46
+ class WrapedTransformerEncoder(nn.Module):
47
+ def __init__(
48
+ self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
49
+ ):
50
+ super().__init__()
51
+ self.input_dim = input_dim
52
+ self.transformer_input_dim = transformer_input_dim
53
+
54
+ if input_dim != transformer_input_dim:
55
+ self.input_proj = nn.Sequential(
56
+ nn.Linear(input_dim, transformer_input_dim),
57
+ nn.LayerNorm(transformer_input_dim),
58
+ nn.GELU(),
59
+ nn.Dropout(dropout * 0.5),
60
+ nn.Linear(transformer_input_dim, transformer_input_dim),
61
+ )
62
+ else:
63
+ self.input_proj = nn.Identity()
64
+
65
+ self.transformer = Encoder(
66
+ dim=transformer_input_dim,
67
+ depth=num_layers,
68
+ heads=nhead,
69
+ layer_dropout=dropout,
70
+ attn_dropout=dropout,
71
+ ff_dropout=dropout,
72
+ attn_flash=True,
73
+ rotary_pos_emb=True,
74
+ )
75
+
76
+ def forward(self, x, src_key_padding_mask=None):
77
+ """
78
+ The input src_key_padding_mask is a B x T boolean mask, where True indicates masked positions.
79
+ However, in x-transformers, False indicates masked positions.
80
+ Therefore, it needs to be converted so that False represents masked positions.
81
+ """
82
+ x = self.input_proj(x)
83
+ mask = (
84
+ ~torch.tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
85
+ if src_key_padding_mask is not None
86
+ else None
87
+ )
88
+ return self.transformer(x, mask=mask)
89
+
90
+
91
+ def prefix_dict(d, prefix: str):
92
+ if not prefix:
93
+ return d
94
+ return {prefix + key: value for key, value in d.items()}
95
+
96
+
97
+ class TimeDownsample(nn.Module):
98
+ def __init__(
99
+ self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
100
+ ):
101
+ super().__init__()
102
+ self.dim_out = dim_out or dim_in
103
+ assert self.dim_out % 2 == 0
104
+
105
+ self.depthwise_conv = nn.Conv1d(
106
+ in_channels=dim_in,
107
+ out_channels=dim_in,
108
+ kernel_size=kernel_size,
109
+ stride=stride,
110
+ padding=padding,
111
+ groups=dim_in,
112
+ bias=False,
113
+ )
114
+ self.pointwise_conv = nn.Conv1d(
115
+ in_channels=dim_in,
116
+ out_channels=self.dim_out,
117
+ kernel_size=1,
118
+ bias=False,
119
+ )
120
+ self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
121
+ self.norm1 = nn.LayerNorm(self.dim_out)
122
+ self.act1 = nn.GELU()
123
+ self.dropout1 = nn.Dropout(dropout)
124
+
125
+ if dim_in != self.dim_out:
126
+ self.residual_conv = nn.Conv1d(
127
+ dim_in, self.dim_out, kernel_size=1, bias=False
128
+ )
129
+ else:
130
+ self.residual_conv = None
131
+
132
+ def forward(self, x):
133
+ residual = x # [B, T, D_in]
134
+ # Convolutional module
135
+ x_c = x.transpose(1, 2) # [B, D_in, T]
136
+ x_c = self.depthwise_conv(x_c) # [B, D_in, T_down]
137
+ x_c = self.pointwise_conv(x_c) # [B, D_out, T_down]
138
+
139
+ # Residual module
140
+ res = self.pool(residual.transpose(1, 2)) # [B, D_in, T]
141
+ if self.residual_conv:
142
+ res = self.residual_conv(res) # [B, D_out, T_down]
143
+ x_c = x_c + res # [B, D_out, T_down]
144
+ x_c = x_c.transpose(1, 2) # [B, T_down, D_out]
145
+ x_c = self.norm1(x_c)
146
+ x_c = self.act1(x_c)
147
+ x_c = self.dropout1(x_c)
148
+ return x_c
149
+
150
+
151
+ class AddFuse(nn.Module):
152
+ def __init__(self):
153
+ super(AddFuse, self).__init__()
154
+
155
+ def forward(self, x, cond):
156
+ return x + cond
157
+
158
+
159
+ class TVLoss1D(nn.Module):
160
+ def __init__(
161
+ self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
162
+ ):
163
+ """
164
+ Args:
165
+ beta: Exponential parameter for TV loss (recommended 0.5~1.0)
166
+ lambda_tv: Overall weight for TV loss
167
+ boundary_threshold: Label threshold to determine if a region is a "boundary area" (e.g., 0.01)
168
+ reduction_weight: Scaling factor for TV penalty within boundary regions (e.g., 0.1, meaning only 10% penalty)
169
+ """
170
+ super().__init__()
171
+ self.beta = beta
172
+ self.lambda_tv = lambda_tv
173
+ self.boundary_threshold = boundary_threshold
174
+ self.reduction_weight = reduction_weight
175
+
176
+ def forward(self, pred, target=None):
177
+ """
178
+ Args:
179
+ pred: (B, T) or (B, T, 1), float boundary scores output by the model
180
+ target: (B, T) or (B, T, 1), ground truth labels (optional, used for spatial weighting if provided)
181
+
182
+ Returns:
183
+ scalar: weighted TV loss
184
+ """
185
+ if pred.dim() == 3:
186
+ pred = pred.squeeze(-1)
187
+ if target is not None and target.dim() == 3:
188
+ target = target.squeeze(-1)
189
+
190
+ diff = pred[:, 1:] - pred[:, :-1]
191
+ tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)
192
+
193
+ if target is None:
194
+ return self.lambda_tv * tv_base.mean()
195
+
196
+ left_in_boundary = target[:, :-1] > self.boundary_threshold
197
+ right_in_boundary = target[:, 1:] > self.boundary_threshold
198
+ near_boundary = left_in_boundary | right_in_boundary
199
+ weight_mask = torch.where(
200
+ near_boundary,
201
+ self.reduction_weight * torch.ones_like(tv_base),
202
+ torch.ones_like(tv_base),
203
+ )
204
+ tv_weighted = (tv_base * weight_mask).mean()
205
+ return self.lambda_tv * tv_weighted
206
+
207
+
208
+ class SoftmaxFocalLoss(nn.Module):
209
+ """
210
+ Softmax Focal Loss for single-label multi-class classification.
211
+ Suitable for mutually exclusive classes.
212
+ """
213
+
214
+ def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
215
+ super().__init__()
216
+ self.alpha = alpha
217
+ self.gamma = gamma
218
+
219
+ def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
220
+ """
221
+ Args:
222
+ pred: [B, T, C], raw logits
223
+ targets: [B, T, C] (soft) or [B, T] (hard, dtype=long)
224
+ Returns:
225
+ loss: scalar or [B, T] depending on reduction
226
+ """
227
+ log_probs = F.log_softmax(pred, dim=-1)
228
+ probs = torch.exp(log_probs)
229
+
230
+ if targets.dtype == torch.long:
231
+ targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
232
+ else:
233
+ targets_onehot = targets
234
+
235
+ p_t = (probs * targets_onehot).sum(dim=-1)
236
+ p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)
237
+
238
+ if self.alpha > 0:
239
+ alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
240
+ 1 - targets_onehot
241
+ )
242
+ alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
243
+ else:
244
+ alpha_weight = 1.0
245
+
246
+ focal_weight = (1 - p_t) ** self.gamma
247
+ ce_loss = -log_probs * targets_onehot
248
+ ce_loss = ce_loss.sum(dim=-1)
249
+
250
+ loss = alpha_weight * focal_weight * ce_loss
251
+ return loss
252
+
253
+
254
+ class Model(nn.Module):
255
+ def __init__(self, config):
256
+ super().__init__()
257
+ self.config = config
258
+
259
+ self.input_norm = nn.LayerNorm(config.input_dim)
260
+ self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
261
+ self.dataset_class_prefix = nn.Embedding(
262
+ num_embeddings=config.num_dataset_classes,
263
+ embedding_dim=config.transformer_encoder_input_dim,
264
+ )
265
+ self.down_sample_conv = TimeDownsample(
266
+ dim_in=config.input_dim,
267
+ dim_out=config.transformer_encoder_input_dim,
268
+ kernel_size=config.down_sample_conv_kernel_size,
269
+ stride=config.down_sample_conv_stride,
270
+ dropout=config.down_sample_conv_dropout,
271
+ padding=config.down_sample_conv_padding,
272
+ )
273
+ self.AddFuse = AddFuse()
274
+ self.transformer = WrapedTransformerEncoder(
275
+ input_dim=config.transformer_encoder_input_dim,
276
+ transformer_input_dim=config.transformer_input_dim,
277
+ num_layers=config.num_transformer_layers,
278
+ nhead=config.transformer_nhead,
279
+ dropout=config.transformer_dropout,
280
+ )
281
+ self.boundary_TVLoss1D = TVLoss1D(
282
+ beta=config.boundary_tv_loss_beta,
283
+ lambda_tv=config.boundary_tv_loss_lambda,
284
+ boundary_threshold=config.boundary_tv_loss_boundary_threshold,
285
+ reduction_weight=config.boundary_tv_loss_reduction_weight,
286
+ )
287
+ self.label_focal_loss = SoftmaxFocalLoss(
288
+ alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
289
+ )
290
+ self.boundary_head = Head(config.transformer_input_dim, 1)
291
+ self.function_head = Head(config.transformer_input_dim, config.num_classes)
292
+
293
+ def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
294
+ assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
295
+ "gt_info and msa_info should end with 'end'"
296
+ )
297
+ gt_info_labels = [label for time_, label in gt_info][:-1]
298
+ gt_info_inters = [time_ for time_, label in gt_info]
299
+ gt_info_inters = np.column_stack(
300
+ [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
301
+ )
302
+
303
+ msa_info_labels = [label for time_, label in msa_info][:-1]
304
+ msa_info_inters = [time_ for time_, label in msa_info]
305
+ msa_info_inters = np.column_stack(
306
+ [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
307
+ )
308
+ result = compute_results(
309
+ ann_inter=gt_info_inters,
310
+ est_inter=msa_info_inters,
311
+ ann_labels=gt_info_labels,
312
+ est_labels=msa_info_labels,
313
+ bins=11,
314
+ est_file="test.txt",
315
+ weight=0.58,
316
+ )
317
+ return result
318
+
319
+ def cal_acc(
320
+ self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
321
+ ):
322
+ ann_info_time = [
323
+ int(round(time_, post_digit) * (10**post_digit))
324
+ for time_, label in ann_info
325
+ ]
326
+ est_info_time = [
327
+ int(round(time_, post_digit) * (10**post_digit))
328
+ for time_, label in est_info
329
+ ]
330
+
331
+ common_start_time = max(ann_info_time[0], est_info_time[0])
332
+ common_end_time = min(ann_info_time[-1], est_info_time[-1])
333
+
334
+ time_points = {common_start_time, common_end_time}
335
+ time_points.update(
336
+ {
337
+ time_
338
+ for time_ in ann_info_time
339
+ if common_start_time <= time_ <= common_end_time
340
+ }
341
+ )
342
+ time_points.update(
343
+ {
344
+ time_
345
+ for time_ in est_info_time
346
+ if common_start_time <= time_ <= common_end_time
347
+ }
348
+ )
349
+
350
+ time_points = sorted(time_points)
351
+ total_duration, total_score = 0, 0
352
+
353
+ for idx in range(len(time_points) - 1):
354
+ duration = time_points[idx + 1] - time_points[idx]
355
+ ann_label = ann_info[
356
+ bisect.bisect_right(ann_info_time, time_points[idx]) - 1
357
+ ][1]
358
+ est_label = est_info[
359
+ bisect.bisect_right(est_info_time, time_points[idx]) - 1
360
+ ][1]
361
+ total_duration += duration
362
+ if ann_label == est_label:
363
+ total_score += duration
364
+ return total_score / total_duration
365
+
366
+ def infer_with_metrics(self, batch, prefix: str = None):
367
+ with torch.no_grad():
368
+ logits = self.forward_func(batch)
369
+
370
+ losses = self.compute_losses(logits, batch, prefix=None)
371
+
372
+ expanded_mask = batch["label_id_masks"].expand(
373
+ -1, logits["function_logits"].size(1), -1
374
+ )
375
+ logits["function_logits"] = logits["function_logits"].masked_fill(
376
+ expanded_mask, -float("inf")
377
+ )
378
+
379
+ msa_info = postprocess_functional_structure(
380
+ logits=logits, config=self.config
381
+ )
382
+ gt_info = batch["msa_infos"][0]
383
+ results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)
384
+
385
+ ret_results = {
386
+ "loss": losses["loss"].item(),
387
+ "HitRate_3P": results["HitRate_3P"],
388
+ "HitRate_3R": results["HitRate_3R"],
389
+ "HitRate_3F": results["HitRate_3F"],
390
+ "HitRate_0.5P": results["HitRate_0.5P"],
391
+ "HitRate_0.5R": results["HitRate_0.5R"],
392
+ "HitRate_0.5F": results["HitRate_0.5F"],
393
+ "PWF": results["PWF"],
394
+ "PWP": results["PWP"],
395
+ "PWR": results["PWR"],
396
+ "Sf": results["Sf"],
397
+ "So": results["So"],
398
+ "Su": results["Su"],
399
+ "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
400
+ }
401
+ if prefix:
402
+ ret_results = prefix_dict(ret_results, prefix)
403
+
404
+ return ret_results
405
+
406
+ def infer(
407
+ self,
408
+ input_embeddings,
409
+ dataset_ids,
410
+ label_id_masks,
411
+ prefix: str = None,
412
+ with_logits=False,
413
+ ):
414
+ with torch.no_grad():
415
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
416
+ input_embeddings = self.input_norm(input_embeddings)
417
+ logits = self.down_sample_conv(input_embeddings)
418
+
419
+ dataset_prefix = self.dataset_class_prefix(dataset_ids)
420
+ dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
421
+ logits.size(0), 1, -1
422
+ )
423
+ logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
424
+ logits = self.transformer(x=logits, src_key_padding_mask=None)
425
+
426
+ function_logits = self.function_head(logits)
427
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
428
+
429
+ logits = {
430
+ "function_logits": function_logits,
431
+ "boundary_logits": boundary_logits,
432
+ }
433
+
434
+ expanded_mask = label_id_masks.expand(
435
+ -1, logits["function_logits"].size(1), -1
436
+ )
437
+ logits["function_logits"] = logits["function_logits"].masked_fill(
438
+ expanded_mask, -float("inf")
439
+ )
440
+
441
+ msa_info = postprocess_functional_structure(
442
+ logits=logits, config=self.config
443
+ )
444
+
445
+ return (msa_info, logits) if with_logits else msa_info
446
+
447
+ def compute_losses(self, outputs, batch, prefix: str = None):
448
+ loss = 0.0
449
+ losses = {}
450
+
451
+ loss_section = F.binary_cross_entropy_with_logits(
452
+ outputs["boundary_logits"],
453
+ batch["widen_true_boundaries"],
454
+ reduction="none",
455
+ )
456
+ loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
457
+ pred=outputs["boundary_logits"],
458
+ target=batch["widen_true_boundaries"],
459
+ )
460
+ loss_function = F.cross_entropy(
461
+ outputs["function_logits"].transpose(1, 2),
462
+ batch["true_functions"].transpose(1, 2),
463
+ reduction="none",
464
+ )
465
+ # input is [B, T, C]
466
+ ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
467
+ pred=outputs["function_logits"], targets=batch["true_functions"]
468
+ )
469
+ loss_function += ttt
470
+
471
+ float_masks = (~batch["masks"]).float()
472
+ boundary_mask = batch.get("boundary_mask", None)
473
+ function_mask = batch.get("function_mask", None)
474
+ if boundary_mask is not None:
475
+ boundary_mask = (~boundary_mask).float()
476
+ else:
477
+ boundary_mask = 1
478
+
479
+ if function_mask is not None:
480
+ function_mask = (~function_mask).float()
481
+ else:
482
+ function_mask = 1
483
+
484
+ loss_section = torch.mean(boundary_mask * float_masks * loss_section)
485
+ loss_function = torch.mean(function_mask * float_masks * loss_function)
486
+
487
+ loss_section *= self.config.loss_weight_section
488
+ loss_function *= self.config.loss_weight_function
489
+
490
+ if self.config.learn_label:
491
+ loss += loss_function
492
+ if self.config.learn_segment:
493
+ loss += loss_section
494
+
495
+ losses.update(
496
+ loss=loss,
497
+ loss_section=loss_section,
498
+ loss_function=loss_function,
499
+ )
500
+ if prefix:
501
+ losses = prefix_dict(losses, prefix)
502
+ return losses
503
+
504
+ def forward_func(self, batch):
505
+ input_embeddings = batch["input_embeddings"]
506
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
507
+ input_embeddings = self.input_norm(input_embeddings)
508
+ logits = self.down_sample_conv(input_embeddings)
509
+
510
+ dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
511
+ logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
512
+ src_key_padding_mask = batch["masks"]
513
+ logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)
514
+
515
+ function_logits = self.function_head(logits)
516
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
517
+
518
+ logits = {
519
+ "function_logits": function_logits,
520
+ "boundary_logits": boundary_logits,
521
+ }
522
+ return logits
523
+
524
+ def forward(self, batch):
525
+ logits = self.forward_func(batch)
526
+ losses = self.compute_losses(logits, batch, prefix=None)
527
+ return logits, losses["loss"], losses
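Note: a minimal shape sanity-check for the backbone above, assuming the defaults in `model_config.py`; the dummy tensors are arbitrary (750 frames is roughly 30 s of 25 Hz SSL features, reduced to 250 frames by the stride-3 downsampling).

```
# Dummy forward pass; shapes only, no pretrained weights involved.
import torch
from model import Model
from model_config import ModelConfig

model = Model(ModelConfig()).eval()
batch = {
    "input_embeddings": torch.randn(1, 750, 4096),       # [B, T, input_dim_raw]
    "dataset_ids": torch.tensor([5], dtype=torch.long),   # SongForm-HX-8Class
    "masks": torch.zeros(1, 250, dtype=torch.bool),       # False = keep position
}
with torch.no_grad():
    out = model.forward_func(batch)
print(out["function_logits"].shape)  # torch.Size([1, 250, 128])
print(out["boundary_logits"].shape)  # torch.Size([1, 250])
```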
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dcabf4ea19973edd51b9e5794004775fa7e8de3ecfa07eb1dbce00f516ce7f7
3
+ size 2755035132
model_config.py ADDED
@@ -0,0 +1,65 @@
1
+ # config.py
2
+ from transformers import PretrainedConfig
3
+
4
+ class ModelConfig(PretrainedConfig):
5
+ model_type = "SongFormer"
6
+
7
+ def __init__(
8
+ self,
9
+ input_dim=2048,
10
+ input_dim_raw=4096,
11
+ transformer_encoder_input_dim=1024,
12
+ transformer_input_dim=512,
13
+ num_transformer_layers=4,
14
+ transformer_nhead=8,
15
+ transformer_dropout=0.1,
16
+ num_classes=128,
17
+ num_dataset_classes=64,
18
+ down_sample_conv_kernel_size=3,
19
+ down_sample_conv_stride=3,
20
+ down_sample_conv_dropout=0.1,
21
+ down_sample_conv_padding=0,
22
+ boundary_tv_loss_beta=0.6,
23
+ boundary_tv_loss_lambda=0.4,
24
+ boundary_tv_loss_boundary_threshold=0.01,
25
+ boundary_tv_loss_reduction_weight=0.1,
26
+ boundary_tvloss_weight=0.05,
27
+ label_focal_loss_alpha=0.25,
28
+ label_focal_loss_gamma=2.0,
29
+ label_focal_loss_weight=0.2,
30
+ loss_weight_section=0.2,
31
+ loss_weight_function=0.8,
32
+ learn_label=True,
33
+ learn_segment=True,
34
+ local_maxima_filter_size=3,
35
+ frame_rates=8.333,
36
+ **kwargs
37
+ ):
38
+ super().__init__(**kwargs)
39
+ self.input_dim = input_dim
40
+ self.input_dim_raw = input_dim_raw
41
+ self.transformer_encoder_input_dim = transformer_encoder_input_dim
42
+ self.transformer_input_dim = transformer_input_dim
43
+ self.num_transformer_layers = num_transformer_layers
44
+ self.transformer_nhead = transformer_nhead
45
+ self.transformer_dropout = transformer_dropout
46
+ self.num_classes = num_classes
47
+ self.num_dataset_classes = num_dataset_classes
48
+ self.down_sample_conv_kernel_size = down_sample_conv_kernel_size
49
+ self.down_sample_conv_stride = down_sample_conv_stride
50
+ self.down_sample_conv_dropout = down_sample_conv_dropout
51
+ self.down_sample_conv_padding = down_sample_conv_padding
52
+ self.boundary_tv_loss_beta = boundary_tv_loss_beta
53
+ self.boundary_tv_loss_lambda = boundary_tv_loss_lambda
54
+ self.boundary_tv_loss_boundary_threshold = boundary_tv_loss_boundary_threshold
55
+ self.boundary_tv_loss_reduction_weight = boundary_tv_loss_reduction_weight
56
+ self.boundary_tvloss_weight = boundary_tvloss_weight
57
+ self.label_focal_loss_alpha = label_focal_loss_alpha
58
+ self.label_focal_loss_gamma = label_focal_loss_gamma
59
+ self.label_focal_loss_weight = label_focal_loss_weight
60
+ self.loss_weight_section = loss_weight_section
61
+ self.loss_weight_function = loss_weight_function
62
+ self.learn_label = learn_label
63
+ self.learn_segment = learn_segment
64
+ self.local_maxima_filter_size = local_maxima_filter_size
65
+ self.frame_rates = frame_rates
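Note: the `frame_rates=8.333` default appears to follow from the constants in `modeling_songformer.py` (an inference, not stated in this commit): 25 Hz SSL features divided by the stride-3 downsampling convolution configured above.

```
# 25 Hz input features / stride-3 temporal downsampling ≈ 8.333 frames per second.
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
DOWN_SAMPLE_CONV_STRIDE = 3
print(BEFORE_DOWNSAMPLING_FRAME_RATES / DOWN_SAMPLE_CONV_STRIDE)  # 8.333...
```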
modeling_songformer.py ADDED
@@ -0,0 +1,328 @@
1
+ import pdb
2
+ from typing import Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel
6
+ import argparse
7
+ import importlib
8
+ import json
9
+ import math
10
+ import multiprocessing as mp
11
+ import os
12
+ import time
13
+ from argparse import Namespace
14
+ from pathlib import Path
15
+
16
+ # monkey patch to fix issues in msaf
17
+ import scipy
18
+ import numpy as np
19
+
20
+ scipy.inf = np.inf
21
+
22
+ import librosa
23
+ import torch
24
+ from ema_pytorch import EMA
25
+ from loguru import logger
26
+ from muq import MuQ
27
+ from musicfm.model.musicfm_25hz import MusicFM25Hz
28
+ from omegaconf import OmegaConf
29
+ from tqdm import tqdm
30
+ import torch
31
+ import torch.nn as nn
32
+ from transformers import PreTrainedModel
33
+ from transformers.modeling_outputs import CausalLMOutputWithPast
34
+ from configuration_songformer import SongFormerConfig
35
+ from model_config import ModelConfig
36
+
37
+ from model import Model
38
+ from omegaconf import OmegaConf
39
+
40
+ # MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
41
+ MUSICFM_HOME_PATH = "/home/node59_tmpdata3/cbhao/SongFormer_kaiyuan_test/github_test/SongFormer/src/SongFormer/ckpts/MusicFM"
42
+
43
+ BEFORE_DOWNSAMPLING_FRAME_RATES = 25
44
+ AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
45
+
46
+ DATASET_LABEL = "SongForm-HX-8Class"
47
+ DATASET_IDS = [5]
48
+
49
+ TIME_DUR = 420
50
+ INPUT_SAMPLING_RATE = 24000
51
+
52
+ from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
53
+ from postprocessing.functional import postprocess_functional_structure
54
+
55
+
56
+ def rule_post_processing(msa_list):
57
+ if len(msa_list) <= 2:
58
+ return msa_list
59
+
60
+ result = msa_list.copy()
61
+
62
+ while len(result) > 2:
63
+ first_duration = result[1][0] - result[0][0]
64
+ if first_duration < 1.0 and len(result) > 2:
65
+ result[0] = (result[0][0], result[1][1])
66
+ result = [result[0]] + result[2:]
67
+ else:
68
+ break
69
+
70
+ while len(result) > 2:
71
+ last_label_duration = result[-1][0] - result[-2][0]
72
+ if last_label_duration < 1.0:
73
+ result = result[:-2] + [result[-1]]
74
+ else:
75
+ break
76
+
77
+ while len(result) > 2:
78
+ if result[0][1] == result[1][1] and result[1][0] <= 10.0:
79
+ result = [(result[0][0], result[0][1])] + result[2:]
80
+ else:
81
+ break
82
+
83
+ while len(result) > 2:
84
+ last_duration = result[-1][0] - result[-2][0]
85
+ if result[-2][1] == result[-3][1] and last_duration <= 10.0:
86
+ result = result[:-2] + [result[-1]]
87
+ else:
88
+ break
89
+
90
+ return result
91
+
92
+
93
+ class SongFormerModel(PreTrainedModel):
94
+ config_class = SongFormerConfig
95
+
96
+ def __init__(self, config: SongFormerConfig):
97
+ super().__init__(config)
98
+ device = "cpu"
99
+
100
+ with open("muq_config2.json", "r") as f:
101
+ muq_config_file = OmegaConf.load(f)
102
+ # self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter", device_map=None)
103
+ self.muq = MuQ(muq_config_file)
104
+
105
+ self.musicfm = MusicFM25Hz(
106
+ is_flash=False,
107
+ stat_path="msd_stats.json",
108
+ # model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
109
+ )
110
+ self.songformer = Model(ModelConfig())
111
+
112
+ num_classes = config.num_classes
113
+ dataset_id2label_mask = {}
114
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
115
+ dataset_id2label_mask[key] = np.ones(config.num_classes, dtype=bool)
116
+ dataset_id2label_mask[key][allowed_ids] = False
117
+
118
+ self.num_classes = num_classes
119
+ self.dataset_id2label_mask = dataset_id2label_mask
120
+ self.config = config
121
+
122
+ def forward(self, input):
123
+ with torch.no_grad():
124
+ INPUT_SAMPLING_RATE = 24000
125
+
126
+ device = next(self.parameters()).device
127
+ # If the input is already a torch.Tensor or a numpy array, use it directly
128
+ if isinstance(input, (torch.Tensor, np.ndarray)):
129
+ audio = torch.tensor(input).to(device)
130
+ elif os.path.exists(input):
131
+ wav, sr = librosa.load(input, sr=INPUT_SAMPLING_RATE)
132
+ audio = torch.tensor(wav).to(device)
133
+ else:
134
+ raise ValueError("input should be a tensor/numpy or a valid file path")
135
+
136
+ win_size = self.config.win_size
137
+ hop_size = self.config.hop_size
138
+ num_classes = self.config.num_classes
139
+ total_len = (
140
+ (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
141
+ ) * TIME_DUR + TIME_DUR
142
+ total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
143
+
144
+ logits = {
145
+ "function_logits": np.zeros([total_frames, num_classes]),
146
+ "boundary_logits": np.zeros([total_frames]),
147
+ }
148
+ logits_num = {
149
+ "function_logits": np.zeros([total_frames, num_classes]),
150
+ "boundary_logits": np.zeros([total_frames]),
151
+ }
152
+
153
+ lens = 0
154
+ i = 0
155
+ while True:
156
+ start_idx = i * INPUT_SAMPLING_RATE
157
+ end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
158
+ if start_idx >= audio.shape[-1]:
159
+ break
160
+ if end_idx - start_idx <= 1024:
161
+ break  # trailing chunk is too short to embed; "continue" would loop forever since i is not advanced
162
+ audio_seg = audio[start_idx:end_idx]
163
+
164
+ # MuQ embedding
165
+ muq_output = self.muq(audio_seg.unsqueeze(0), output_hidden_states=True)
166
+ muq_embd_420s = muq_output["hidden_states"][10]
167
+ del muq_output
168
+ torch.cuda.empty_cache()
169
+
170
+ # MusicFM embedding
171
+ _, musicfm_hidden_states = self.musicfm.get_predictions(
172
+ audio_seg.unsqueeze(0)
173
+ )
174
+ musicfm_embd_420s = musicfm_hidden_states[10]
175
+ del musicfm_hidden_states
176
+ torch.cuda.empty_cache()
177
+
178
+ wraped_muq_embd_30s = []
179
+ wraped_musicfm_embd_30s = []
180
+
181
+ for idx_30s in range(i, i + hop_size, 30):
182
+ start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
183
+ end_idx_30s = min(
184
+ (idx_30s + 30) * INPUT_SAMPLING_RATE,
185
+ audio.shape[-1],
186
+ (i + hop_size) * INPUT_SAMPLING_RATE,
187
+ )
188
+ if start_idx_30s >= audio.shape[-1]:
189
+ break
190
+ if end_idx_30s - start_idx_30s <= 1024:
191
+ continue
192
+ wraped_muq_embd_30s.append(
193
+ self.muq(
194
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0),
195
+ output_hidden_states=True,
196
+ )["hidden_states"][10]
197
+ )
198
+ torch.cuda.empty_cache()
199
+ wraped_musicfm_embd_30s.append(
200
+ self.musicfm.get_predictions(
201
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0)
202
+ )[1][10]
203
+ )
204
+ torch.cuda.empty_cache()
205
+
206
+ wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
207
+ wraped_musicfm_embd_30s = torch.concatenate(
208
+ wraped_musicfm_embd_30s, dim=1
209
+ )
210
+ all_embds = [
211
+ wraped_musicfm_embd_30s,
212
+ wraped_muq_embd_30s,
213
+ musicfm_embd_420s,
214
+ muq_embd_420s,
215
+ ]
216
+
217
+ if len(all_embds) > 1:
218
+ embd_lens = [x.shape[1] for x in all_embds]
219
+ max_embd_len = max(embd_lens)
220
+ min_embd_len = min(embd_lens)
221
+ if abs(max_embd_len - min_embd_len) > 4:
222
+ raise ValueError(
223
+ f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
224
+ )
225
+
226
+ for idx in range(len(all_embds)):
227
+ all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
228
+
229
+ embd = torch.concatenate(all_embds, axis=-1)
230
+
231
+ dataset_label = DATASET_LABEL
232
+ dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
233
+ msa_info, chunk_logits = self.songformer.infer(
234
+ input_embeddings=embd,
235
+ dataset_ids=dataset_ids,
236
+ label_id_masks=torch.Tensor(
237
+ self.dataset_id2label_mask[
238
+ DATASET_LABEL_TO_DATASET_ID[dataset_label]
239
+ ]
240
+ )
241
+ .to(device, dtype=bool)
242
+ .unsqueeze(0)
243
+ .unsqueeze(0),
244
+ with_logits=True,
245
+ )
246
+
247
+ start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
248
+ end_frame = start_frame + min(
249
+ math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
250
+ chunk_logits["boundary_logits"][0].shape[0],
251
+ )
252
+
253
+ logits["function_logits"][start_frame:end_frame, :] += (
254
+ chunk_logits["function_logits"][0].detach().cpu().numpy()
255
+ )
256
+ logits["boundary_logits"][start_frame:end_frame] = (
257
+ chunk_logits["boundary_logits"][0].detach().cpu().numpy()
258
+ )
259
+ logits_num["function_logits"][start_frame:end_frame, :] += 1
260
+ logits_num["boundary_logits"][start_frame:end_frame] += 1
261
+ lens += end_frame - start_frame
262
+
263
+ i += hop_size
264
+ logits["function_logits"] /= logits_num["function_logits"]
265
+ logits["boundary_logits"] /= logits_num["boundary_logits"]
266
+
267
+ logits["function_logits"] = torch.from_numpy(
268
+ logits["function_logits"][:lens]
269
+ ).unsqueeze(0)
270
+ logits["boundary_logits"] = torch.from_numpy(
271
+ logits["boundary_logits"][:lens]
272
+ ).unsqueeze(0)
273
+
274
+ msa_infer_output = postprocess_functional_structure(logits, self.config)
275
+
276
+ assert msa_infer_output[-1][-1] == "end"
277
+ if not self.config.no_rule_post_processing:
278
+ msa_infer_output = rule_post_processing(msa_infer_output)
279
+ msa_json = []
280
+ for idx in range(len(msa_infer_output) - 1):
281
+ msa_json.append(
282
+ {
283
+ "label": msa_infer_output[idx][1],
284
+ "start": msa_infer_output[idx][0],
285
+ "end": msa_infer_output[idx + 1][0],
286
+ }
287
+ )
288
+ return msa_json
289
+
290
+ @staticmethod
291
+ def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
292
+ """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
293
+
294
+ # ---- begin: ignore muq ----
295
+ if key.startswith("muq."):
296
+ return key, False
297
+ # ---- end ---
298
+
299
+ # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
300
+ # This rename is logged.
301
+ if key.endswith("LayerNorm.beta"):
302
+ return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
303
+ if key.endswith("LayerNorm.gamma"):
304
+ return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
305
+
306
+ # Rename weight norm parametrizations to match changes across torch versions.
307
+ # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
308
+ # This rename is not logged.
309
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
310
+ if key.endswith("weight_g"):
311
+ return key.replace(
312
+ "weight_g", "parametrizations.weight.original0"
313
+ ), True
314
+ if key.endswith("weight_v"):
315
+ return key.replace(
316
+ "weight_v", "parametrizations.weight.original1"
317
+ ), True
318
+ else:
319
+ if key.endswith("parametrizations.weight.original0"):
320
+ return key.replace(
321
+ "parametrizations.weight.original0", "weight_g"
322
+ ), True
323
+ if key.endswith("parametrizations.weight.original1"):
324
+ return key.replace(
325
+ "parametrizations.weight.original1", "weight_v"
326
+ ), True
327
+
328
+ return key, False
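Note: for reference, a small worked example of the `rule_post_processing` helper above; the timestamps are made up. A leading segment shorter than 1 s is merged into its successor and takes the successor's label.

```
# Illustrative only: the 0.4 s "silence" lead-in is absorbed into "intro".
# (Assumes modeling_songformer and its dependencies are importable.)
from modeling_songformer import rule_post_processing

msa = [(0.0, "silence"), (0.4, "intro"), (12.5, "verse"), (200.0, "end")]
print(rule_post_processing(msa))
# -> [(0.0, "intro"), (12.5, "verse"), (200.0, "end")]
```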
msd_stats.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "spec_256_cnt": 14394344256,
3
+ "spec_256_mean": -23.34296658431829,
4
+ "spec_256_std": 26.189295587132637,
5
+ "spec_512_cnt": 28677104448,
6
+ "spec_512_mean": -21.31267396860235,
7
+ "spec_512_std": 26.52644536245769,
8
+ "spec_1024_cnt": 57242624832,
9
+ "spec_1024_mean": -18.852271129208273,
10
+ "spec_1024_std": 26.443154583585663,
11
+ "spec_2048_cnt": 114373665600,
12
+ "spec_2048_mean": -15.638743433896792,
13
+ "spec_2048_std": 26.115825961611545,
14
+ "spec_4096_cnt": 228635747136,
15
+ "spec_4096_mean": -11.715532502794836,
16
+ "spec_4096_std": 25.763972210234062,
17
+ "melspec_256_cnt": 14282760192,
18
+ "melspec_256_mean": -26.962600400166156,
19
+ "melspec_256_std": 36.13614100912126,
20
+ "melspec_512_cnt": 14282760192,
21
+ "melspec_512_mean": -9.108344167718862,
22
+ "melspec_512_std": 24.71910937988429,
23
+ "melspec_1024_cnt": 14282760192,
24
+ "melspec_1024_mean": 0.37302579246531126,
25
+ "melspec_1024_std": 18.684082325919388,
26
+ "melspec_2048_cnt": 14282760192,
27
+ "melspec_2048_mean": 6.768444971712967,
28
+ "melspec_2048_std": 18.417922652295623,
29
+ "melspec_4096_cnt": 14282760192,
30
+ "melspec_4096_mean": 13.617164614990036,
31
+ "melspec_4096_std": 18.08552130124525,
32
+ "cqt_cnt": 9373061376,
33
+ "cqt_mean": 0.46341379757927165,
34
+ "cqt_std": 0.9543998080910191,
35
+ "mfcc_256_cnt": 1339008768,
36
+ "mfcc_256_mean": -11.681755459447485,
37
+ "mfcc_256_std": 29.183186444668316,
38
+ "mfcc_512_cnt": 1339008768,
39
+ "mfcc_512_mean": -2.540581461792183,
40
+ "mfcc_512_std": 31.93752185832081,
41
+ "mfcc_1024_cnt": 1339008768,
42
+ "mfcc_1024_mean": 6.606636263169779,
43
+ "mfcc_1024_std": 34.151644801729624,
44
+ "mfcc_2048_cnt": 1339008768,
45
+ "mfcc_2048_mean": 5.281600844245184,
46
+ "mfcc_2048_std": 33.12784541220003,
47
+ "mfcc_4096_cnt": 1339008768,
48
+ "mfcc_4096_mean": 4.7616569480166095,
49
+ "mfcc_4096_std": 32.61458906894133,
50
+ "chromagram_256_cnt": 1339008768,
51
+ "chromagram_256_mean": 55.15596556703181,
52
+ "chromagram_256_std": 73.91858278719991,
53
+ "chromagram_512_cnt": 1339008768,
54
+ "chromagram_512_mean": 175.73092252759895,
55
+ "chromagram_512_std": 248.48485148525953,
56
+ "chromagram_1024_cnt": 1339008768,
57
+ "chromagram_1024_mean": 589.2947481634608,
58
+ "chromagram_1024_std": 913.857929063196,
59
+ "chromagram_2048_cnt": 1339008768,
60
+ "chromagram_2048_mean": 2062.286388327397,
61
+ "chromagram_2048_std": 3458.92657915397,
62
+ "chromagram_4096_cnt": 1339008768,
63
+ "chromagram_4096_mean": 7673.039107997085,
64
+ "chromagram_4096_std": 13009.883158267234
65
+ }
muq_config2.json ADDED
@@ -0,0 +1,143 @@
1
+ {
2
+ "label_rate": 25,
3
+ "num_codebooks": 1,
4
+ "codebook_dim": 16,
5
+ "codebook_size": 8192,
6
+ "features": [
7
+ "melspec_2048"
8
+ ],
9
+ "hop_length": 240,
10
+ "n_mels": 128,
11
+ "conv_dim": 512,
12
+ "encoder_dim": 1024,
13
+ "encoder_depth": 12,
14
+ "mask_hop": 0.4,
15
+ "mask_prob": 0.6,
16
+ "is_flash": false,
17
+ "stat": {
18
+ "melspec_2048_cnt": 14282760192,
19
+ "melspec_2048_mean": 6.768444971712967,
20
+ "melspec_2048_std": 18.417922652295623
21
+ },
22
+ "w2v2_config": {
23
+ "activation_dropout": 0.1,
24
+ "adapter_kernel_size": 3,
25
+ "adapter_stride": 2,
26
+ "add_adapter": false,
27
+ "apply_spec_augment": true,
28
+ "architectures": [
29
+ "Wav2Vec2ConformerForCTC"
30
+ ],
31
+ "attention_dropout": 0.1,
32
+ "bos_token_id": 1,
33
+ "classifier_proj_size": 256,
34
+ "codevector_dim": 768,
35
+ "conformer_conv_dropout": 0.1,
36
+ "contrastive_logits_temperature": 0.1,
37
+ "conv_bias": true,
38
+ "conv_depthwise_kernel_size": 31,
39
+ "conv_dim": [
40
+ 512,
41
+ 512,
42
+ 512,
43
+ 512,
44
+ 512,
45
+ 512,
46
+ 512
47
+ ],
48
+ "conv_kernel": [
49
+ 10,
50
+ 3,
51
+ 3,
52
+ 3,
53
+ 3,
54
+ 2,
55
+ 2
56
+ ],
57
+ "conv_stride": [
58
+ 5,
59
+ 2,
60
+ 2,
61
+ 2,
62
+ 2,
63
+ 2,
64
+ 2
65
+ ],
66
+ "ctc_loss_reduction": "sum",
67
+ "ctc_zero_infinity": false,
68
+ "diversity_loss_weight": 0.1,
69
+ "do_stable_layer_norm": true,
70
+ "eos_token_id": 2,
71
+ "feat_extract_activation": "gelu",
72
+ "feat_extract_dropout": 0.0,
73
+ "feat_extract_norm": "layer",
74
+ "feat_proj_dropout": 0.1,
75
+ "feat_quantizer_dropout": 0.0,
76
+ "final_dropout": 0.1,
77
+ "gradient_checkpointing": false,
78
+ "hidden_act": "swish",
79
+ "hidden_dropout": 0.1,
80
+ "hidden_dropout_prob": 0.1,
81
+ "hidden_size": 1024,
82
+ "initializer_range": 0.02,
83
+ "intermediate_size": 4096,
84
+ "layer_norm_eps": 1e-05,
85
+ "layerdrop": 0.0,
86
+ "mask_feature_length": 10,
87
+ "mask_feature_min_masks": 0,
88
+ "mask_feature_prob": 0.0,
89
+ "mask_time_length": 10,
90
+ "mask_time_min_masks": 2,
91
+ "mask_time_prob": 0.05,
92
+ "max_source_positions": 5000,
93
+ "model_type": "wav2vec2-conformer",
94
+ "num_adapter_layers": 3,
95
+ "num_attention_heads": 16,
96
+ "num_codevector_groups": 2,
97
+ "num_codevectors_per_group": 320,
98
+ "num_conv_pos_embedding_groups": 16,
99
+ "num_conv_pos_embeddings": 128,
100
+ "num_feat_extract_layers": 7,
101
+ "num_hidden_layers": 24,
102
+ "num_negatives": 100,
103
+ "output_hidden_size": 1024,
104
+ "pad_token_id": 0,
105
+ "position_embeddings_type": "rotary",
106
+ "proj_codevector_dim": 768,
107
+ "rotary_embedding_base": 10000,
108
+ "tdnn_dilation": [
109
+ 1,
110
+ 2,
111
+ 3,
112
+ 1,
113
+ 1
114
+ ],
115
+ "tdnn_dim": [
116
+ 512,
117
+ 512,
118
+ 512,
119
+ 512,
120
+ 1500
121
+ ],
122
+ "tdnn_kernel": [
123
+ 5,
124
+ 3,
125
+ 3,
126
+ 1,
127
+ 1
128
+ ],
129
+ "torch_dtype": "float32",
130
+ "transformers_version": "4.19.0.dev0",
131
+ "use_weighted_layer_sum": false,
132
+ "vocab_size": 32,
133
+ "xvector_output_dim": 512
134
+ },
135
+ "use_rvq_target": true,
136
+ "use_vq_target": false,
137
+ "use_encodec_target": false,
138
+ "rvq_ckpt_path": null,
139
+ "recon_loss_ratio": null,
140
+ "resume_checkpoint": null,
141
+ "rvq_n_codebooks": 8,
142
+ "rvq_multi_layer_num": 1
143
+ }
musicfm/.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ # mac
2
+ .DS_Store
3
+
4
+ # cache
5
+ *.pyc
6
+
7
+ # data
8
+ *.json
9
+ *.pt
10
+
musicfm/LICENSE ADDED
@@ -0,0 +1,224 @@
1
+ Dual Licensing Information
2
+ -------------------------
3
+
4
+ This software is dual-licensed under both the MIT License and the Apache License, Version 2.0.
5
+
6
+ - The file `modules/flash_conformer.py` is distributed under the terms of the Apache License, Version 2.0.
7
+ - All other files and modules in this software are distributed under the terms of the MIT License.
8
+
9
+ ### MIT License
10
+
11
+ Copyright 2023 ByteDance Inc.
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18
+
19
+
20
+ ### Apache License, Version 2.0
21
+
22
+ Copyright 2018- The Hugging Face team. All rights reserved.
23
+
24
+ Apache License
25
+ Version 2.0, January 2004
26
+ http://www.apache.org/licenses/
27
+
28
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
29
+
30
+ 1. Definitions.
31
+
32
+ "License" shall mean the terms and conditions for use, reproduction,
33
+ and distribution as defined by Sections 1 through 9 of this document.
34
+
35
+ "Licensor" shall mean the copyright owner or entity authorized by
36
+ the copyright owner that is granting the License.
37
+
38
+ "Legal Entity" shall mean the union of the acting entity and all
39
+ other entities that control, are controlled by, or are under common
40
+ control with that entity. For the purposes of this definition,
41
+ "control" means (i) the power, direct or indirect, to cause the
42
+ direction or management of such entity, whether by contract or
43
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
44
+ outstanding shares, or (iii) beneficial ownership of such entity.
45
+
46
+ "You" (or "Your") shall mean an individual or Legal Entity
47
+ exercising permissions granted by this License.
48
+
49
+ "Source" form shall mean the preferred form for making modifications,
50
+ including but not limited to software source code, documentation
51
+ source, and configuration files.
52
+
53
+ "Object" form shall mean any form resulting from mechanical
54
+ transformation or translation of a Source form, including but
55
+ not limited to compiled object code, generated documentation,
56
+ and conversions to other media types.
57
+
58
+ "Work" shall mean the work of authorship, whether in Source or
59
+ Object form, made available under the License, as indicated by a
60
+ copyright notice that is included in or attached to the work
61
+ (an example is provided in the Appendix below).
62
+
63
+ "Derivative Works" shall mean any work, whether in Source or Object
64
+ form, that is based on (or derived from) the Work and for which the
65
+ editorial revisions, annotations, elaborations, or other modifications
66
+ represent, as a whole, an original work of authorship. For the purposes
67
+ of this License, Derivative Works shall not include works that remain
68
+ separable from, or merely link (or bind by name) to the interfaces of,
69
+ the Work and Derivative Works thereof.
70
+
71
+ "Contribution" shall mean any work of authorship, including
72
+ the original version of the Work and any modifications or additions
73
+ to that Work or Derivative Works thereof, that is intentionally
74
+ submitted to Licensor for inclusion in the Work by the copyright owner
75
+ or by an individual or Legal Entity authorized to submit on behalf of
76
+ the copyright owner. For the purposes of this definition, "submitted"
77
+ means any form of electronic, verbal, or written communication sent
78
+ to the Licensor or its representatives, including but not limited to
79
+ communication on electronic mailing lists, source code control systems,
80
+ and issue tracking systems that are managed by, or on behalf of, the
81
+ Licensor for the purpose of discussing and improving the Work, but
82
+ excluding communication that is conspicuously marked or otherwise
83
+ designated in writing by the copyright owner as "Not a Contribution."
84
+
85
+ "Contributor" shall mean Licensor and any individual or Legal Entity
86
+ on behalf of whom a Contribution has been received by Licensor and
87
+ subsequently incorporated within the Work.
88
+
89
+ 2. Grant of Copyright License. Subject to the terms and conditions of
90
+ this License, each Contributor hereby grants to You a perpetual,
91
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
92
+ copyright license to reproduce, prepare Derivative Works of,
93
+ publicly display, publicly perform, sublicense, and distribute the
94
+ Work and such Derivative Works in Source or Object form.
95
+
96
+ 3. Grant of Patent License. Subject to the terms and conditions of
97
+ this License, each Contributor hereby grants to You a perpetual,
98
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
99
+ (except as stated in this section) patent license to make, have made,
100
+ use, offer to sell, sell, import, and otherwise transfer the Work,
101
+ where such license applies only to those patent claims licensable
102
+ by such Contributor that are necessarily infringed by their
103
+ Contribution(s) alone or by combination of their Contribution(s)
104
+ with the Work to which such Contribution(s) was submitted. If You
105
+ institute patent litigation against any entity (including a
106
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
107
+ or a Contribution incorporated within the Work constitutes direct
108
+ or contributory patent infringement, then any patent licenses
109
+ granted to You under this License for that Work shall terminate
110
+ as of the date such litigation is filed.
111
+
112
+ 4. Redistribution. You may reproduce and distribute copies of the
113
+ Work or Derivative Works thereof in any medium, with or without
114
+ modifications, and in Source or Object form, provided that You
115
+ meet the following conditions:
116
+
117
+ (a) You must give any other recipients of the Work or
118
+ Derivative Works a copy of this License; and
119
+
120
+ (b) You must cause any modified files to carry prominent notices
121
+ stating that You changed the files; and
122
+
123
+ (c) You must retain, in the Source form of any Derivative Works
124
+ that You distribute, all copyright, patent, trademark, and
125
+ attribution notices from the Source form of the Work,
126
+ excluding those notices that do not pertain to any part of
127
+ the Derivative Works; and
128
+
129
+ (d) If the Work includes a "NOTICE" text file as part of its
130
+ distribution, then any Derivative Works that You distribute must
131
+ include a readable copy of the attribution notices contained
132
+ within such NOTICE file, excluding those notices that do not
133
+ pertain to any part of the Derivative Works, in at least one
134
+ of the following places: within a NOTICE text file distributed
135
+ as part of the Derivative Works; within the Source form or
136
+ documentation, if provided along with the Derivative Works; or,
137
+ within a display generated by the Derivative Works, if and
138
+ wherever such third-party notices normally appear. The contents
139
+ of the NOTICE file are for informational purposes only and
140
+ do not modify the License. You may add Your own attribution
141
+ notices within Derivative Works that You distribute, alongside
142
+ or as an addendum to the NOTICE text from the Work, provided
143
+ that such additional attribution notices cannot be construed
144
+ as modifying the License.
145
+
146
+ You may add Your own copyright statement to Your modifications and
147
+ may provide additional or different license terms and conditions
148
+ for use, reproduction, or distribution of Your modifications, or
149
+ for any such Derivative Works as a whole, provided Your use,
150
+ reproduction, and distribution of the Work otherwise complies with
151
+ the conditions stated in this License.
152
+
153
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
154
+ any Contribution intentionally submitted for inclusion in the Work
155
+ by You to the Licensor shall be under the terms and conditions of
156
+ this License, without any additional terms or conditions.
157
+ Notwithstanding the above, nothing herein shall supersede or modify
158
+ the terms of any separate license agreement you may have executed
159
+ with Licensor regarding such Contributions.
160
+
161
+ 6. Trademarks. This License does not grant permission to use the trade
162
+ names, trademarks, service marks, or product names of the Licensor,
163
+ except as required for reasonable and customary use in describing the
164
+ origin of the Work and reproducing the content of the NOTICE file.
165
+
166
+ 7. Disclaimer of Warranty. Unless required by applicable law or
167
+ agreed to in writing, Licensor provides the Work (and each
168
+ Contributor provides its Contributions) on an "AS IS" BASIS,
169
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
170
+ implied, including, without limitation, any warranties or conditions
171
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
172
+ PARTICULAR PURPOSE. You are solely responsible for determining the
173
+ appropriateness of using or redistributing the Work and assume any
174
+ risks associated with Your exercise of permissions under this License.
175
+
176
+ 8. Limitation of Liability. In no event and under no legal theory,
177
+ whether in tort (including negligence), contract, or otherwise,
178
+ unless required by applicable law (such as deliberate and grossly
179
+ negligent acts) or agreed to in writing, shall any Contributor be
180
+ liable to You for damages, including any direct, indirect, special,
181
+ incidental, or consequential damages of any character arising as a
182
+ result of this License or out of the use or inability to use the
183
+ Work (including but not limited to damages for loss of goodwill,
184
+ work stoppage, computer failure or malfunction, or any and all
185
+ other commercial damages or losses), even if such Contributor
186
+ has been advised of the possibility of such damages.
187
+
188
+ 9. Accepting Warranty or Additional Liability. While redistributing
189
+ the Work or Derivative Works thereof, You may choose to offer,
190
+ and charge a fee for, acceptance of support, warranty, indemnity,
191
+ or other liability obligations and/or rights consistent with this
192
+ License. However, in accepting such obligations, You may act only
193
+ on Your own behalf and on Your sole responsibility, not on behalf
194
+ of any other Contributor, and only if You agree to indemnify,
195
+ defend, and hold each Contributor harmless for any liability
196
+ incurred by, or claims asserted against, such Contributor by reason
197
+ of your accepting any such warranty or additional liability.
198
+
199
+ END OF TERMS AND CONDITIONS
200
+
201
+ APPENDIX: How to apply the Apache License to your work.
202
+
203
+ To apply the Apache License to your work, attach the following
204
+ boilerplate notice, with the fields enclosed by brackets "[]"
205
+ replaced with your own identifying information. (Don't include
206
+ the brackets!) The text should be enclosed in the appropriate
207
+ comment syntax for the file format. We also recommend that a
208
+ file or class name and description of purpose be included on the
209
+ same "printed page" as the copyright notice for easier
210
+ identification within third-party archives.
211
+
212
+ Copyright [yyyy] [name of copyright owner]
213
+
214
+ Licensed under the Apache License, Version 2.0 (the "License");
215
+ you may not use this file except in compliance with the License.
216
+ You may obtain a copy of the License at
217
+
218
+ http://www.apache.org/licenses/LICENSE-2.0
219
+
220
+ Unless required by applicable law or agreed to in writing, software
221
+ distributed under the License is distributed on an "AS IS" BASIS,
222
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
223
+ See the License for the specific language governing permissions and
224
+ limitations under the License.
musicfm/README.md ADDED
@@ -0,0 +1,173 @@
1
+ # MusicFM 🤖
2
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
3
+ [![License](https://img.shields.io/github/license/openshift/source-to-image.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
4
+
5
+
6
+ **A Foundation Model for Music Informatics**, ICASSP 2024 [[paper](https://arxiv.org/abs/2311.03318)]
7
+
8
+ -- Minz Won, Yun-Ning Hung, and Duc Le
9
+
10
+
11
+ ## Quick start
12
+ ### Download models
13
+
14
+ **MusicFM-FMA**
15
+
16
+ - Pretrained using [FMA-large](https://github.com/mdeff/fma) data
17
+
18
+ ```
19
+ wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/fma_stats.json
20
+ wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_fma.pt
21
+ ```
22
+ ⚠️ The model checkpoint prior to Feb 13, 2024, was incorrect. Please make sure to re-download these files if you've been using previous versions.
23
+
24
+
25
+ **MusicFM-MSD**
26
+
27
+ - Pretrained with the entire [Million Song Dataset](http://millionsongdataset.com/)
28
+ - This version performs better than the FMA version
29
+ - This version is not introduced in the paper
30
+
31
+ ```
32
+ wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json
33
+ wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt
34
+ ```
35
+
36
+ ### Get embeddings
37
+ ```
38
+ HOME_PATH = "/home/dev" # path where you cloned musicfm
39
+
40
+ import os
41
+ import sys
42
+ import torch
43
+
44
+ sys.path.append(HOME_PATH)
45
+ from musicfm.model.musicfm_25hz import MusicFM25Hz
46
+
47
+ # dummy audio (30 seconds, 24kHz)
48
+ wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
49
+
50
+ # load MusicFM
51
+ musicfm = MusicFM25Hz(
52
+ is_flash=False,
53
+ stat_path=os.path.join(HOME_PATH, "musicfm", "data", "msd_stats.json"),
54
+ model_path=os.path.join(HOME_PATH, "musicfm", "data", "pretrained_msd.pt"),
55
+ )
56
+
57
+ # to GPUs
58
+ wav = wav.cuda()
59
+ musicfm = musicfm.cuda()
60
+
61
+ # get embeddings
62
+ musicfm.eval()
63
+ emb = musicfm.get_latent(wav, layer_ix=7)
64
+ ```
65
+
66
+ ### Mixed precision and Flash attention
67
+ Suffering from memory issues? [Mixed precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) and [Flash attention](https://arxiv.org/abs/2205.14135) will be good friends of yours!
68
+
69
+ ```
70
+ # dummy audio (30 seconds, 24kHz)
71
+ wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
72
+
73
+ # load MusicFM
74
+ musicfm = MusicFM25Hz(is_flash=True)
75
+
76
+ # to GPUs
77
+ wav = wav.cuda().half()
78
+ musicfm = musicfm.cuda().half()
79
+
80
+ # get embeddings
81
+ musicfm.eval()
82
+ emb = musicfm.get_latent(wav, layer_ix=7)
83
+ ```
84
+
85
+ However, I highly recommend using `float32` for better performance in specific downstream tasks, such as beat tracking.
86
+
87
+ ### Usage in downstream tasks
88
+ The pretrained model operates at a 25Hz frame rate, but our downstream tasks demand varying temporal resolutions. To address this, we either summarize the sequence through global average pooling or adjust the temporal resolution using adaptive average pooling.
89
+
90
+ ```
91
+ from torch import nn
92
+
93
+ # Sequence-level representation
94
+ seq_emb = emb.mean(1) # (batch, time, channel) -> (batch, channel)
95
+
96
+ # Frame-level representation
97
+ """
98
+ n_frame = desired_temporal_resolution * sequence_length_in_sec
99
+ 300 frames = 10Hz * 30s in this example
100
+ As a result, the sequence length is reduced from 750 (25Hz * 30s) to 300
101
+ """
102
+ n_frame = 300
103
+ token_emb = nn.AdaptiveAvgPool1d(n_frame)(emb.transpose(1, 2)).transpose(1, 2) # (batch, time, channel) -> (batch, time', channel)
104
+ ```
105
+ We share the details of our downstream evaluation as follows. The selection of input lengths and temporal resolutions is based on our prior experience with each task.
106
+
107
+ | | Beat | Chord | Structure | Key | Tagging |
108
+ | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: |
109
+ | Input length | 6s | 12s | 24s | 12s | 29.1s |
110
+ | Temporal resolution | 50Hz | 16Hz | 8Hz | 0.5Hz | - |
111
+ | n_frame | 300 | 192 | 192 | 6 | 1 |
112
+
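Each `n_frame` above is just the input length times the temporal resolution (for example 8Hz * 24s = 192 for structure analysis). Below is a minimal sketch that wraps the pooling into a helper; the function name is illustrative and a (batch, time, channel) layout of `emb` is assumed, not something defined by this repository:

```
import torch
from torch import nn

def pool_to_frame_rate(emb, target_hz, input_sec):
    """Average-pool (batch, time, channel) embeddings down to a target frame rate."""
    n_frame = max(1, round(target_hz * input_sec))               # e.g. 8 Hz * 24 s = 192
    pooled = nn.AdaptiveAvgPool1d(n_frame)(emb.transpose(1, 2))  # pool over the time axis
    return pooled.transpose(1, 2)                                # (batch, n_frame, channel)

# 24 s at the 25 Hz backbone rate gives 600 frames; structure analysis pools them to 192
emb = torch.randn(2, 600, 1024)
assert pool_to_frame_rate(emb, target_hz=8, input_sec=24).shape == (2, 192, 1024)
```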
113
+ ### Fine-tuning
114
+ You can expect better performance in downstream tasks by fine-tuning the foundation model. In this scenario, employ `musicfm.train()` and extract the final embeddings by setting `layer_ix=12`. However, when optimizing the model with the same learning rate, there's a risk of [catastrophic forgetting](https://en.wikipedia.org/wiki/Catastrophic_interference). To mitigate this issue, we utilized a learning rate of 1e-5 for the foundation model and 1e-4 for the probing layers.
115
+
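As a rough sketch of the two learning rates described above, the backbone and the probing head can be placed in separate optimizer parameter groups. This reuses the `musicfm` object from the quick-start snippet; the linear probing head is a hypothetical stand-in, not part of this repository:

```
from torch import nn, optim

probe = nn.Linear(1024, 10)  # hypothetical 10-class probing head on 1024-dim embeddings
optimizer = optim.AdamW([
    {"params": musicfm.parameters(), "lr": 1e-5},  # foundation model: small learning rate
    {"params": probe.parameters(), "lr": 1e-4},    # probing layers: larger learning rate
])
musicfm.train()  # fine-tune end to end and read embeddings with layer_ix=12
```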
116
+
117
+
118
+ ## Results
119
+
120
+ <img src="figs/Table1.png" width="800">
121
+
122
+ \* FM1 is pretrained [MERT](https://arxiv.org/abs/2306.00107).
123
+
124
+ \*\*FM8 mirrors the [BEST-RQ](https://arxiv.org/abs/2202.01855) but with the distinction that it was trained using music data.
125
+
126
+
127
+ - Random tokenization generalizes well to music data.
128
+
129
+ - Frame-level classification offers a more comprehensive understanding of foundation models. While FM4 excels in music tagging, its performance in structural analysis is subpar.
130
+
131
+ - Input length used during training is critical for capturing
132
+ long-term contexts. Compare the 5s models (FM1, FM2, and FM4) with the 30s model (FM5) on downbeat tracking and structure analysis.
133
+
134
+ - Temporal resolution has less impact in our experimental setup. See FM5, FM6, and FM7.
135
+
136
+ - Model architecture makes a significant difference. Conformer (FM5) consistently outperformed the BERT encoder (FM3) across all downstream tasks.
137
+
138
+ - The influence of model size was relatively minimal (FM7 and FM8). However, we observed that FM8's performance continued to improve, which is typically indicative of underfitting. All models were trained for two weeks to ensure a fair comparison.
139
+
140
+ - Data is undeniably crucial, as in any data-driven approach. Please compare FM7 and FM9.
141
+
142
+ - Fine-tuning the foundation model further enhances downstream performance. However, we did observe a performance
143
+ drop in the tagging task, primarily attributed to overfitting.
144
+
145
+ ## Masked token modeling
146
+ <img src="figs/Fig1.png" width="300">
147
+
148
+ MusicFM follows the training scheme of [BEST-RQ](https://arxiv.org/abs/2202.01855). Input audio is masked with noise, and the model predicts the masked representation. Target tokens are generated by random projection and a random codebook. Both the projection layer and codebook are **randomly initialized** and remain **non-trainable**. Isn't it fascinating?
149
+
150
+ Note that input normalization is exceptionally crucial, considering the usage of random projection. You can check the details [here](https://github.com/minzwon/musicfm/blob/d5d0f313add9f3c32c41f95521760b1a136809ed/model/musicfm_25hz.py#L148).
151
+
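The target tokenizer can be summarized in a few lines. The following is a simplified sketch of the BEST-RQ idea only; it is not the `RandomProjectionQuantizer` shipped with this code, and the dimensions are illustrative:

```
import torch

class RandomTokenizer(torch.nn.Module):
    """Random projection + random codebook; both frozen, as in BEST-RQ."""

    def __init__(self, input_dim=512, codebook_dim=16, codebook_size=4096, seed=142):
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        # buffers, not parameters: initialized once and never trained
        self.register_buffer("proj", torch.randn(input_dim, codebook_dim, generator=g))
        self.register_buffer("codebook", torch.randn(codebook_size, codebook_dim, generator=g))

    @torch.no_grad()
    def forward(self, x):                        # x: (batch, time, input_dim), normalized features
        b, t, _ = x.shape
        z = (x @ self.proj).reshape(b * t, -1)   # project into the codebook space
        dist = torch.cdist(z, self.codebook)     # distance to every code vector
        return dist.argmin(-1).reshape(b, t)     # integer target tokens
```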
152
+ ## Limitations
153
+ - Self-supervised foundation models in music, such as [JukeMIR](https://arxiv.org/abs/2107.05677), [MERT](https://arxiv.org/abs/2306.00107), and [MusicFM](https://arxiv.org/abs/2311.03318), consistently report relatively low performance in key detection. While fine-tuning the model can help bridge the performance gap, the foundation model itself does not appear to learn musical keys inherently. Further investigation is required to develop more advanced music foundation models.
154
+
155
+ - We share our model trained with the [FMA Dataset](https://github.com/mdeff/fma), which comprises 8k hours of Creative Commons-licensed audio. While using a larger dataset (160k hours) can enhance performance, we've chosen to release the model trained on FMA to avoid potential licensing complications.
156
+
157
+ - Fine-tuned models for downstream tasks are not made publicly available as they are primarily used for evaluation purposes. It is expected that carefully designed backends beyond simple probing layers will improve downstream performance. I look forward to the contributions of other researchers with more expertise in each specific task.
158
+
159
+ - The downstream evaluation pipeline is not provided in this repository. Nonetheless, I believe creating a comprehensive evaluation pipeline is essential to expedite progress in music informatics research. I'm very open to discussing it together.
160
+
161
+
162
+ ## Acknowledgement
163
+ We acknowledge and extend our sincere gratitude to Ju-Chiang Wang for his valuable contributions to data refinement and providing a crucial codebase for our downstream evaluation.
164
+
165
+ ## Citation
166
+ ```
167
+ @article{won2023musicfm,
168
+ title={A Foundation Model for Music Informatics},
169
+ author = {Won, Minz and Hung, Yun-Ning and Le, Duc},
170
+ journal={arXiv preprint arXiv:2311.03318},
171
+ year={2023}
172
+ }
173
+ ```
musicfm/data/.gitkeep ADDED
File without changes
musicfm/figs/Fig1.png ADDED

Git LFS Details

  • SHA256: bbbb7a435402555125e996c747a619585906bc2cb7911afa5521ac35af1201e3
  • Pointer size: 131 Bytes
  • Size of remote file: 396 kB
musicfm/figs/Table1.png ADDED

Git LFS Details

  • SHA256: 773da09077da92f3e41fb9a53aff4efc559dcfc07d2e65b142cf23af2de512d7
  • Pointer size: 131 Bytes
  • Size of remote file: 807 kB
musicfm/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
musicfm/model/musicfm_25hz.py ADDED
@@ -0,0 +1,252 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import json
17
+ import random
18
+ import torch
19
+ from torch import nn
20
+ from einops import rearrange
21
+
22
+ from musicfm.modules.random_quantizer import RandomProjectionQuantizer
23
+ from musicfm.modules.features import MelSTFT
24
+ from musicfm.modules.conv import Conv2dSubsampling
25
+
26
+
27
+ class MusicFM25Hz(nn.Module):
28
+ """
29
+ MusicFM
30
+
31
+ Input: 128-band mel spectrogram
32
+ Frontend: 2-layer Residual convolution
33
+ Backend: 12-layer Conformer
34
+ Quantizer: a codebook for mel spectrogram
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ num_codebooks=1,
40
+ codebook_dim=16,
41
+ codebook_size=4096,
42
+ features=["melspec_2048"],
43
+ hop_length=240,
44
+ n_mels=128,
45
+ conv_dim=512,
46
+ encoder_dim=1024,
47
+ encoder_depth=12,
48
+ mask_hop=0.4,
49
+ mask_prob=0.6,
50
+ is_flash=False,
51
+ stat_path="./data/fma_stats.json",
52
+ # model_path="./data/pretrained_fma.pt",
53
+ ):
54
+ super(MusicFM25Hz, self).__init__()
55
+
56
+ # global variables
57
+ self.hop_length = hop_length
58
+ self.mask_hop = mask_hop
59
+ self.mask_prob = mask_prob
60
+ self.num_codebooks = num_codebooks
61
+ self.codebook_size = codebook_size
62
+ self.features = features
63
+
64
+ # load feature mean / std stats
65
+ with open(stat_path, "r") as f:
66
+ self.stat = json.load(f)
67
+
68
+ # feature extractor
69
+ self.preprocessor_melspec_2048 = MelSTFT(
70
+ n_fft=2048, hop_length=hop_length, is_db=True
71
+ )
72
+
73
+ # random quantizer
74
+ seed = 142
75
+ for feature in self.features:
76
+ for i in range(num_codebooks):
77
+ setattr(
78
+ self,
79
+ f"quantizer_{feature}_{i}",
80
+ RandomProjectionQuantizer(
81
+ n_mels * 4, codebook_dim, codebook_size, seed=seed + i
82
+ ),
83
+ )
84
+
85
+ # two residual convolution layers + one projection layer
86
+ self.conv = Conv2dSubsampling(
87
+ 1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
88
+ )
89
+
90
+ # Conformer
91
+ if is_flash:
92
+ from musicfm.modules.flash_conformer import (
93
+ Wav2Vec2ConformerEncoder,
94
+ Wav2Vec2ConformerConfig,
95
+ )
96
+ else:
97
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
98
+ Wav2Vec2ConformerEncoder,
99
+ Wav2Vec2ConformerConfig,
100
+ )
101
+ config = Wav2Vec2ConformerConfig.from_pretrained(
102
+ "facebook/wav2vec2-conformer-rope-large-960h-ft"
103
+ )
104
+ config.num_hidden_layers = encoder_depth
105
+ config.hidden_size = encoder_dim
106
+
107
+ self.conformer = Wav2Vec2ConformerEncoder(config)
108
+
109
+ # projection
110
+ self.linear = nn.Linear(encoder_dim, codebook_size)
111
+
112
+ # loss function
113
+ self.loss = nn.CrossEntropyLoss()
114
+
115
+ # cls token (used for sequence classification)
116
+ random.seed(seed)
117
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
118
+
119
+ # load model
120
+ # if model_path:
121
+ # S = torch.load(model_path)["state_dict"]
122
+ # SS = {k[6:]: v for k, v in S.items()}
123
+ # self.load_state_dict(SS, strict=True)
124
+
125
+ def masking(self, x):
126
+ """random masking of 400ms with given probability"""
127
+ mx = x.clone()
128
+ b, t = mx.shape
129
+ len_masking_raw = int(24000 * self.mask_hop)
130
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
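+ # with the defaults (24 kHz audio, hop_length=240, two stride-2 conv layers), tokens come at
+ # 24000 / 240 / 2 / 2 = 25 Hz, so a 0.4 s mask_hop spans 9600 samples and 10 tokens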
131
+
132
+ # get random mask indices
133
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
134
+ time_domain_masked_indices = torch.nonzero(
135
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
136
+ )
137
+ token_domain_masked_indices = torch.nonzero(
138
+ start_indices.repeat_interleave(len_masking_token, dim=1)
139
+ )
140
+
141
+ # mask with random values
142
+ masking_noise = (
143
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
144
+ ) # 0 mean 0.1 std
145
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
146
+
147
+ return mx, token_domain_masked_indices
148
+
149
+ @torch.no_grad()
150
+ def preprocessing(self, x, features):
151
+ """extract classic audio features"""
152
+ # check precision
153
+ if x.dtype == torch.float16:
154
+ precision = 16
155
+ else:
156
+ precision = 32
157
+
158
+ out = {}
159
+ for key in features:
160
+ layer = getattr(self, "preprocessor_%s" % key)
161
+ out[key] = layer.float()(x.float())[..., :-1]
162
+ if precision == 16:
163
+ out[key] = out[key].half()
164
+ return out
165
+
166
+ def encoder(self, x):
167
+ """2-layer conv + w2v-conformer"""
168
+ x = self.conv(x)
169
+ out = self.conformer(x, output_hidden_states=True)
170
+ hidden_emb = out["hidden_states"]
171
+ last_emb = out["last_hidden_state"]
172
+ logits = self.linear(last_emb)
173
+ logits = {
174
+ key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
175
+ for i, key in enumerate(self.features)
176
+ }
177
+ return logits, hidden_emb
178
+
179
+ @torch.no_grad()
180
+ def normalize(self, x):
181
+ """normalize the input features to have zero mean and unit variance"""
182
+ for key in x.keys():
183
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
184
+ return x
185
+
186
+ @torch.no_grad()
187
+ def rearrange(self, x):
188
+ """rearrange the batch to flatten every 4 steps"""
189
+ for key in x.keys():
190
+ if key == "chromagram":
191
+ x[key] = rearrange(x[key], "b f t -> b t f")
192
+ else:
193
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
194
+ return x
195
+
196
+ @torch.no_grad()
197
+ def tokenize(self, x):
198
+ out = {}
199
+ for key in x.keys():
200
+ layer = getattr(self, "quantizer_%s" % key)
201
+ out[key] = layer(x[key])
202
+ return out
203
+
204
+ def get_targets(self, x):
205
+ x = self.preprocessing(x, features=self.features)
206
+ x = self.normalize(x)
207
+ x = self.rearrange(x)
208
+ target_tokens = self.tokenize(x)
209
+ return target_tokens
210
+
211
+ def get_predictions(self, x):
212
+ # preprocessing
213
+ x = self.preprocessing(x, features=["melspec_2048"])
214
+ x = self.normalize(x)
215
+
216
+ # encoding
217
+ logits, hidden_emb = self.encoder(x["melspec_2048"])
218
+
219
+ return logits, hidden_emb
220
+
221
+ def get_latent(self, x, layer_ix=12):
222
+ _, hidden_states = self.get_predictions(x)
223
+ emb = hidden_states[layer_ix]
224
+ return emb
225
+
226
+ def get_loss(self, logits, target_tokens, masked_indices):
227
+ losses = {}
228
+ accuracies = {}
229
+ for key in logits.keys():
230
+ masked_logits = logits[key][tuple(masked_indices.t())]
231
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
232
+ losses[key] = self.loss(masked_logits, masked_tokens)
233
+ accuracies[key] = (
234
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
235
+ / masked_tokens.numel()
236
+ )
237
+ return losses, accuracies
238
+
239
+ def forward(self, x):
240
+ # get target feature tokens
241
+ target_tokens = self.get_targets(x)
242
+
243
+ # masking
244
+ x, masked_indices = self.masking(x)
245
+
246
+ # forward
247
+ logits, hidden_emb = self.get_predictions(x)
248
+
249
+ # get loss
250
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
251
+
252
+ return logits, hidden_emb, losses, accuracies
musicfm/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
musicfm/modules/conv.py ADDED
@@ -0,0 +1,82 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ from torch import nn
17
+ from einops import rearrange
18
+
19
+
20
+ class Res2dModule(nn.Module):
21
+ def __init__(self, idim, odim, stride=(2, 2)):
22
+ super(Res2dModule, self).__init__()
23
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
24
+ self.bn1 = nn.BatchNorm2d(odim)
25
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
26
+ self.bn2 = nn.BatchNorm2d(odim)
27
+ self.relu = nn.ReLU()
28
+
29
+ # residual
30
+ self.diff = False
31
+ if (idim != odim) or (stride[0] > 1):
32
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
33
+ self.bn3 = nn.BatchNorm2d(odim)
34
+ self.diff = True
35
+
36
+ def forward(self, x):
37
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
38
+ if self.diff:
39
+ x = self.bn3(self.conv3(x))
40
+ out = x + out
41
+ out = self.relu(out)
42
+ return out
43
+
44
+
45
+ class Conv2dSubsampling(nn.Module):
46
+ """Convolutional 2D subsampling (to 1/4 length).
47
+
48
+ Args:
49
+ idim (int): Input dimension.
50
+ hdim (int): Hidden dimension.
51
+ odim (int): Output dimension.
52
+ strides (list): Sizes of strides.
53
+ n_bands (int): Number of frequency bands.
54
+ """
55
+
56
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
57
+ """Construct an Conv2dSubsampling object."""
58
+ super(Conv2dSubsampling, self).__init__()
59
+
60
+ self.conv = nn.Sequential(
61
+ Res2dModule(idim, hdim, (2, strides[0])),
62
+ Res2dModule(hdim, hdim, (2, strides[1])),
63
+ )
64
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
65
+
66
+ def forward(self, x):
67
+ """Subsample x.
68
+
69
+ Args:
70
+ x (torch.Tensor): Input tensor (#batch, n_bands, time).
71
+
72
+ Returns:
73
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
74
+ where time' = time // 4.
75
+ """
76
+
77
+ if x.dim() == 3:
78
+ x = x.unsqueeze(1) # (b, c, f, t)
79
+ x = self.conv(x)
80
+ x = rearrange(x, "b c f t -> b t (c f)")
81
+ x = self.linear(x)
82
+ return x
musicfm/modules/features.py ADDED
@@ -0,0 +1,45 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import torchaudio
17
+ from torch import nn
18
+
19
+
20
+ class MelSTFT(nn.Module):
21
+ def __init__(
22
+ self,
23
+ sample_rate=24000,
24
+ n_fft=2048,
25
+ hop_length=240,
26
+ n_mels=128,
27
+ is_db=False,
28
+ ):
29
+ super(MelSTFT, self).__init__()
30
+
31
+ # spectrogram
32
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
33
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
34
+ )
35
+
36
+ # amplitude to decibel
37
+ self.is_db = is_db
38
+ if is_db:
39
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
40
+
41
+ def forward(self, waveform):
42
+ if self.is_db:
43
+ return self.amplitude_to_db(self.mel_stft(waveform))
44
+ else:
45
+ return self.mel_stft(waveform)
musicfm/modules/flash_conformer.py ADDED
@@ -0,0 +1,2114 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Wav2Vec2-Conformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ from torch.nn import functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ CausalLMOutput,
33
+ SequenceClassifierOutput,
34
+ TokenClassifierOutput,
35
+ Wav2Vec2BaseModelOutput,
36
+ XVectorOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ ModelOutput,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ _HIDDEN_STATES_START_POSITION = 2
54
+
55
+ # General docstring
56
+ _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
57
+
58
+ # Base docstring
59
+ _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
61
+
62
+ # CTC docstring
63
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
64
+ _CTC_EXPECTED_LOSS = 64.21
65
+
66
+
67
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
68
+ "facebook/wav2vec2-conformer-rel-pos-large",
69
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
70
+ ]
71
+
72
+
73
+ @dataclass
74
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
75
+ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
76
+ """
77
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
78
+
79
+ Args:
80
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
81
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
82
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
83
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
84
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
85
+ projected quantized states.
86
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
87
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
88
+ target vectors for contrastive loss.
89
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
91
+ shape `(batch_size, sequence_length, hidden_size)`.
92
+
93
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
+ sequence_length)`.
97
+
98
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
+ heads.
100
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
101
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
102
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
103
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
104
+ """
105
+
106
+ loss: Optional[torch.FloatTensor] = None
107
+ projected_states: torch.FloatTensor = None
108
+ projected_quantized_states: torch.FloatTensor = None
109
+ codevector_perplexity: torch.FloatTensor = None
110
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
112
+ contrastive_loss: Optional[torch.FloatTensor] = None
113
+ diversity_loss: Optional[torch.FloatTensor] = None
114
+
115
+
116
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
117
+ def _compute_mask_indices(
118
+ shape: Tuple[int, int],
119
+ mask_prob: float,
120
+ mask_length: int,
121
+ attention_mask: Optional[torch.LongTensor] = None,
122
+ min_masks: int = 0,
123
+ ) -> np.ndarray:
124
+ """
125
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
126
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
127
+ CPU as part of the preprocessing during training.
128
+
129
+ Args:
130
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
131
+ the first element is the batch size and the second element is the length of the axis to span.
132
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
133
+ independently generated mask spans of length `mask_length` is computed by
134
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
135
+ actual percentage will be smaller.
136
+ mask_length: size of the mask
137
+ min_masks: minimum number of masked spans
138
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
139
+ each batch dimension.
140
+ """
141
+ batch_size, sequence_length = shape
142
+
143
+ if mask_length < 1:
144
+ raise ValueError("`mask_length` has to be bigger than 0.")
145
+
146
+ if mask_length > sequence_length:
147
+ raise ValueError(
148
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
149
+ f" and `sequence_length`: {sequence_length}`"
150
+ )
151
+
152
+ # epsilon is used for probabilistic rounding
153
+ epsilon = np.random.rand(1).item()
154
+
155
+ def compute_num_masked_span(input_length):
156
+ """Given input length, compute how many spans should be masked"""
157
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
158
+ num_masked_span = max(num_masked_span, min_masks)
159
+
160
+ # make sure num masked span <= sequence_length
161
+ if num_masked_span * mask_length > sequence_length:
162
+ num_masked_span = sequence_length // mask_length
163
+
164
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
165
+ if input_length - (mask_length - 1) < num_masked_span:
166
+ num_masked_span = max(input_length - (mask_length - 1), 0)
167
+
168
+ return num_masked_span
169
+
170
+ # compute number of masked spans in batch
171
+ input_lengths = (
172
+ attention_mask.sum(-1).detach().tolist()
173
+ if attention_mask is not None
174
+ else [sequence_length for _ in range(batch_size)]
175
+ )
176
+
177
+ # SpecAugment mask to fill
178
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
179
+ spec_aug_mask_idxs = []
180
+
181
+ max_num_masked_span = compute_num_masked_span(sequence_length)
182
+
183
+ if max_num_masked_span == 0:
184
+ return spec_aug_mask
185
+
186
+ for input_length in input_lengths:
187
+ # compute num of masked spans for this input
188
+ num_masked_span = compute_num_masked_span(input_length)
189
+
190
+ # get random indices to mask
191
+ spec_aug_mask_idx = np.random.choice(
192
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
193
+ )
194
+
195
+ # pick first sampled index that will serve as a dummy index to pad vector
196
+ # to ensure same dimension for all batches due to probabilistic rounding
197
+ # Picking first sample just pads those vectors twice.
198
+ if len(spec_aug_mask_idx) == 0:
199
+ # this case can only happen if `input_length` is strictly smaller then
200
+ # `sequence_length` in which case the last token has to be a padding
201
+ # token which we can use as a dummy mask id
202
+ dummy_mask_idx = sequence_length - 1
203
+ else:
204
+ dummy_mask_idx = spec_aug_mask_idx[0]
205
+
206
+ spec_aug_mask_idx = np.concatenate(
207
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
208
+ )
209
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
210
+
211
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
212
+
213
+ # expand masked indices to masked spans
214
+ spec_aug_mask_idxs = np.broadcast_to(
215
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
216
+ )
217
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
218
+
219
+ # add offset to the starting indexes so that indexes now create a span
220
+ offsets = np.arange(mask_length)[None, None, :]
221
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
222
+ batch_size, max_num_masked_span * mask_length
223
+ )
224
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
225
+
226
+ # ensure that we cannot have indices larger than sequence_length
227
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
228
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
229
+
230
+ # scatter indices to mask
231
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
232
+
233
+ return spec_aug_mask
234
+
235
+
236
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
237
+ def _sample_negative_indices(
238
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
239
+ ):
240
+ """
241
+ Sample `num_negatives` vectors from feature vectors.
242
+ """
243
+ batch_size, sequence_length = features_shape
244
+
245
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
246
+ sequence_length_range = np.arange(sequence_length)
247
+
248
+ # get `num_negatives` random vector indices from the same utterance
249
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
250
+
251
+ mask_time_indices = (
252
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
253
+ )
254
+
255
+ for batch_idx in range(batch_size):
256
+ high = mask_time_indices[batch_idx].sum() - 1
257
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
258
+
259
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
260
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
261
+ # avoid sampling the same positive vector, but keep the distribution uniform
262
+ sampled_indices[sampled_indices >= feature_indices] += 1
263
+
264
+ # remap to actual indices
265
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
266
+
267
+ # correct for batch size
268
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
269
+
270
+ return sampled_negative_indices
271
+
272
+
273
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
274
+ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
275
+ def __init__(self, config, layer_id=0):
276
+ super().__init__()
277
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
278
+ self.out_conv_dim = config.conv_dim[layer_id]
279
+
280
+ self.conv = nn.Conv1d(
281
+ self.in_conv_dim,
282
+ self.out_conv_dim,
283
+ kernel_size=config.conv_kernel[layer_id],
284
+ stride=config.conv_stride[layer_id],
285
+ bias=config.conv_bias,
286
+ )
287
+ self.activation = ACT2FN[config.feat_extract_activation]
288
+
289
+ def forward(self, hidden_states):
290
+ hidden_states = self.conv(hidden_states)
291
+ hidden_states = self.activation(hidden_states)
292
+ return hidden_states
293
+
294
+
295
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
296
+ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
297
+ def __init__(self, config, layer_id=0):
298
+ super().__init__()
299
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
300
+ self.out_conv_dim = config.conv_dim[layer_id]
301
+
302
+ self.conv = nn.Conv1d(
303
+ self.in_conv_dim,
304
+ self.out_conv_dim,
305
+ kernel_size=config.conv_kernel[layer_id],
306
+ stride=config.conv_stride[layer_id],
307
+ bias=config.conv_bias,
308
+ )
309
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
310
+ self.activation = ACT2FN[config.feat_extract_activation]
311
+
312
+ def forward(self, hidden_states):
313
+ hidden_states = self.conv(hidden_states)
314
+
315
+ hidden_states = hidden_states.transpose(-2, -1)
316
+ hidden_states = self.layer_norm(hidden_states)
317
+ hidden_states = hidden_states.transpose(-2, -1)
318
+
319
+ hidden_states = self.activation(hidden_states)
320
+ return hidden_states
321
+
322
+
323
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
324
+ class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
325
+ def __init__(self, config, layer_id=0):
326
+ super().__init__()
327
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
328
+ self.out_conv_dim = config.conv_dim[layer_id]
329
+
330
+ self.conv = nn.Conv1d(
331
+ self.in_conv_dim,
332
+ self.out_conv_dim,
333
+ kernel_size=config.conv_kernel[layer_id],
334
+ stride=config.conv_stride[layer_id],
335
+ bias=config.conv_bias,
336
+ )
337
+ self.activation = ACT2FN[config.feat_extract_activation]
338
+
339
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
340
+
341
+ def forward(self, hidden_states):
342
+ hidden_states = self.conv(hidden_states)
343
+ hidden_states = self.layer_norm(hidden_states)
344
+ hidden_states = self.activation(hidden_states)
345
+ return hidden_states
346
+
347
+
348
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
349
+ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.conv = nn.Conv1d(
353
+ config.hidden_size,
354
+ config.hidden_size,
355
+ kernel_size=config.num_conv_pos_embeddings,
356
+ padding=config.num_conv_pos_embeddings // 2,
357
+ groups=config.num_conv_pos_embedding_groups,
358
+ )
359
+
360
+ if is_deepspeed_zero3_enabled():
361
+ import deepspeed
362
+
363
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
364
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
365
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
366
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
367
+ else:
368
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
369
+
370
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
371
+ self.activation = ACT2FN[config.feat_extract_activation]
372
+
373
+ def forward(self, hidden_states):
374
+ hidden_states = hidden_states.transpose(1, 2)
375
+
376
+ hidden_states = self.conv(hidden_states)
377
+ hidden_states = self.padding(hidden_states)
378
+ hidden_states = self.activation(hidden_states)
379
+
380
+ hidden_states = hidden_states.transpose(1, 2)
381
+ return hidden_states
382
+
383
+
384
+ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
385
+ """Rotary positional embedding
386
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
387
+ """
388
+
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ dim = config.hidden_size // config.num_attention_heads
392
+ base = config.rotary_embedding_base
393
+
394
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
395
+ self.register_buffer("inv_freq", inv_freq)
396
+ self.cached_sequence_length = None
397
+ self.cached_rotary_positional_embedding = None
398
+
399
+ def forward(self, hidden_states):
400
+ sequence_length = hidden_states.shape[1]
401
+
402
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
403
+ return self.cached_rotary_positional_embedding
404
+
405
+ self.cached_sequence_length = sequence_length
406
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
407
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
408
+ embeddings = torch.cat((freqs, freqs), dim=-1)
409
+
410
+ cos_embeddings = embeddings.cos()[:, None, None, :]
411
+ sin_embeddings = embeddings.sin()[:, None, None, :]
412
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
413
+ return self.cached_rotary_positional_embedding
414
+
415
+
416
+ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
417
+ """Relative positional encoding module."""
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.max_len = config.max_source_positions
422
+ self.d_model = config.hidden_size
423
+ self.pe = None
424
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
425
+
426
+ def extend_pe(self, x):
427
+ # Reset the positional encodings
428
+ if self.pe is not None:
429
+ # self.pe contains both positive and negative parts
430
+ # the length of self.pe is 2 * input_len - 1
431
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
432
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
433
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
434
+ return
435
+ # Suppose `i` is the position of query vector and `j` is the
436
+ # position of key vector. We use positive relative positions when keys
437
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
438
+ pe_positive = torch.zeros(x.size(1), self.d_model)
439
+ pe_negative = torch.zeros(x.size(1), self.d_model)
440
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
441
+ div_term = torch.exp(
442
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
443
+ )
444
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
445
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
446
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
447
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
448
+
449
+ # Reverse the order of positive indices and concat both positive and
450
+ # negative indices. This is used to support the shifting trick
451
+ # as in https://arxiv.org/abs/1901.02860
452
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
453
+ pe_negative = pe_negative[1:].unsqueeze(0)
454
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
455
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
456
+
457
+ def forward(self, hidden_states: torch.Tensor):
458
+ self.extend_pe(hidden_states)
459
+ start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
460
+ end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
461
+ relative_position_embeddings = self.pe[:, start_idx:end_idx]
462
+
463
+ return relative_position_embeddings
464
+
465
+
466
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
467
+ class Wav2Vec2ConformerSamePadLayer(nn.Module):
468
+ def __init__(self, num_conv_pos_embeddings):
469
+ super().__init__()
470
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
471
+
472
+ def forward(self, hidden_states):
473
+ if self.num_pad_remove > 0:
474
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
475
+ return hidden_states
476
+
477
+
478
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
479
+ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
480
+ """Construct the features from raw audio waveform"""
481
+
482
+ def __init__(self, config):
483
+ super().__init__()
484
+
485
+ if config.feat_extract_norm == "group":
486
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
487
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
488
+ for i in range(config.num_feat_extract_layers - 1)
489
+ ]
490
+ elif config.feat_extract_norm == "layer":
491
+ conv_layers = [
492
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
493
+ ]
494
+ else:
495
+ raise ValueError(
496
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
497
+ )
498
+ self.conv_layers = nn.ModuleList(conv_layers)
499
+ self.gradient_checkpointing = False
500
+ self._requires_grad = True
501
+
502
+ def _freeze_parameters(self):
503
+ for param in self.parameters():
504
+ param.requires_grad = False
505
+ self._requires_grad = False
506
+
507
+ def forward(self, input_values):
508
+ hidden_states = input_values[:, None]
509
+
510
+ # make sure hidden_states require grad for gradient_checkpointing
511
+ if self._requires_grad and self.training:
512
+ hidden_states.requires_grad = True
513
+
514
+ for conv_layer in self.conv_layers:
515
+ if self._requires_grad and self.gradient_checkpointing and self.training:
516
+
517
+ def create_custom_forward(module):
518
+ def custom_forward(*inputs):
519
+ return module(*inputs)
520
+
521
+ return custom_forward
522
+
523
+ hidden_states = torch.utils.checkpoint.checkpoint(
524
+ create_custom_forward(conv_layer),
525
+ hidden_states,
526
+ )
527
+ else:
528
+ hidden_states = conv_layer(hidden_states)
529
+
530
+ return hidden_states
531
+
532
+
533
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
534
+ class Wav2Vec2ConformerFeatureProjection(nn.Module):
535
+ def __init__(self, config):
536
+ super().__init__()
537
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
538
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
539
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
540
+
541
+ def forward(self, hidden_states):
542
+ # non-projected hidden states are needed for quantization
543
+ norm_hidden_states = self.layer_norm(hidden_states)
544
+ hidden_states = self.projection(norm_hidden_states)
545
+ hidden_states = self.dropout(hidden_states)
546
+ return hidden_states, norm_hidden_states
547
+
548
+
549
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
550
+ class Wav2Vec2ConformerFeedForward(nn.Module):
551
+ def __init__(self, config):
552
+ super().__init__()
553
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
554
+
555
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
556
+ if isinstance(config.hidden_act, str):
557
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
558
+ else:
559
+ self.intermediate_act_fn = config.hidden_act
560
+
561
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
562
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
563
+
564
+ def forward(self, hidden_states):
565
+ hidden_states = self.intermediate_dense(hidden_states)
566
+ hidden_states = self.intermediate_act_fn(hidden_states)
567
+ hidden_states = self.intermediate_dropout(hidden_states)
568
+
569
+ hidden_states = self.output_dense(hidden_states)
570
+ hidden_states = self.output_dropout(hidden_states)
571
+ return hidden_states
572
+
573
+
574
+ class Wav2Vec2ConformerConvolutionModule(nn.Module):
575
+ """Convolution block used in the conformer block"""
576
+
577
+ def __init__(self, config):
578
+ super().__init__()
579
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
580
+ raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
581
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
582
+ self.pointwise_conv1 = torch.nn.Conv1d(
583
+ config.hidden_size,
584
+ 2 * config.hidden_size,
585
+ kernel_size=1,
586
+ stride=1,
587
+ padding=0,
588
+ bias=False,
589
+ )
590
+ self.glu = torch.nn.GLU(dim=1)
591
+ self.depthwise_conv = torch.nn.Conv1d(
592
+ config.hidden_size,
593
+ config.hidden_size,
594
+ config.conv_depthwise_kernel_size,
595
+ stride=1,
596
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
597
+ groups=config.hidden_size,
598
+ bias=False,
599
+ )
600
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
601
+ self.activation = ACT2FN[config.hidden_act]
602
+ self.pointwise_conv2 = torch.nn.Conv1d(
603
+ config.hidden_size,
604
+ config.hidden_size,
605
+ kernel_size=1,
606
+ stride=1,
607
+ padding=0,
608
+ bias=False,
609
+ )
610
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
611
+
612
+ def forward(self, hidden_states):
613
+ hidden_states = self.layer_norm(hidden_states)
614
+ # exchange the temporal dimension and the feature dimension
615
+ hidden_states = hidden_states.transpose(1, 2)
616
+
617
+ # GLU mechanism
618
+ # => (batch, 2*channel, dim)
619
+ hidden_states = self.pointwise_conv1(hidden_states)
620
+ # => (batch, channel, dim)
621
+ hidden_states = self.glu(hidden_states)
622
+
623
+ # 1D Depthwise Conv
624
+ hidden_states = self.depthwise_conv(hidden_states)
625
+ hidden_states = self.batch_norm(hidden_states)
626
+ hidden_states = self.activation(hidden_states)
627
+
628
+ hidden_states = self.pointwise_conv2(hidden_states)
629
+ hidden_states = self.dropout(hidden_states)
630
+ hidden_states = hidden_states.transpose(1, 2)
631
+ return hidden_states
632
+
633
+
634
+ class Wav2Vec2ConformerSelfAttention(nn.Module):
635
+ """Construct an Wav2Vec2ConformerSelfAttention object.
636
+ Can be enhanced with rotary or relative position embeddings.
637
+ """
638
+
639
+ def __init__(self, config):
640
+ super().__init__()
641
+
642
+ self.head_size = config.hidden_size // config.num_attention_heads
643
+ self.num_heads = config.num_attention_heads
644
+ self.position_embeddings_type = config.position_embeddings_type
645
+
646
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
647
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
648
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
649
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
650
+
651
+ self.dropout = nn.Dropout(p=config.attention_dropout)
652
+ self.dropout_p = config.attention_dropout
653
+
654
+ self.is_causal = config.is_causal
655
+
656
+ if self.position_embeddings_type == "relative":
657
+ # linear transformation for positional encoding
658
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
659
+ # these two learnable biases are used in matrix c and matrix d
660
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
661
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
662
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
663
+
664
+ def forward(
665
+ self,
666
+ hidden_states: torch.Tensor,
667
+ attention_mask: Optional[torch.Tensor] = None,
668
+ relative_position_embeddings: Optional[torch.Tensor] = None,
669
+ output_attentions: bool = False,
670
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
671
+ # self-attention mechanism
672
+ batch_size, sequence_length, hidden_size = hidden_states.size()
673
+
674
+ # make sure query/key states can be != value states
675
+ query_key_states = hidden_states
676
+ value_states = hidden_states
677
+
678
+ if self.position_embeddings_type == "rotary":
679
+ if relative_position_embeddings is None:
680
+ raise ValueError(
681
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
682
+ )
683
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
684
+
685
+ # project query_key_states and value_states
686
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
687
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
688
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
689
+
690
+ # => (batch, head, time1, d_k)
691
+ query = query.transpose(1, 2)
692
+ key = key.transpose(1, 2)
693
+ value = value.transpose(1, 2)
694
+
695
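+ # note: the flash-attention SDPA backend is forced here and attention probabilities are never
+ # materialized, so `probs` below stays None even when `output_attentions=True` is requested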
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
696
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p if self.training else 0.0, is_causal=self.is_causal)
697
+ probs = None
698
+
699
+ # # apply attention_mask if necessary
700
+ # if attention_mask is not None:
701
+ # scores = scores + attention_mask
702
+
703
+ # # => (batch, head, time1, time2)
704
+ # probs = torch.softmax(scores, dim=-1)
705
+ # probs = self.dropout(probs)
706
+
707
+ # # => (batch, head, time1, d_k)
708
+ # hidden_states = torch.matmul(probs, value)
709
+
710
+ # => (batch, time1, hidden_size)
711
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
712
+ hidden_states = self.linear_out(hidden_states)
713
+
714
+ return hidden_states, probs
715
+
716
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
717
+ batch_size, sequence_length, hidden_size = hidden_states.size()
718
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
719
+
720
+ cos = relative_position_embeddings[0, :sequence_length, ...]
721
+ sin = relative_position_embeddings[1, :sequence_length, ...]
722
+
723
+ # rotate hidden_states with rotary embeddings
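+ # standard "rotate-half" formulation: each head is split into two halves (x1, x2) and recombined as x * cos + rotate_half(x) * sin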
724
+ hidden_states = hidden_states.transpose(0, 1)
725
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
726
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
727
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
728
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
729
+ hidden_states = hidden_states.transpose(0, 1)
730
+
731
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
732
+
733
+ return hidden_states
734
+
735
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
736
+ # 1. project positional embeddings
737
+ # => (batch, head, 2*time1-1, d_k)
738
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
739
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
740
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
741
+ )
742
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
743
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
744
+
745
+ # 2. Add bias to query
746
+ # => (batch, head, time1, d_k)
747
+ query = query.transpose(1, 2)
748
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
749
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
750
+
751
+ # 3. attention score: first compute matrix a and matrix c
752
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
753
+ # => (batch, head, time1, time2)
754
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
755
+
756
+ # 4. then compute matrix b and matrix d
757
+ # => (batch, head, time1, 2*time1-1)
758
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
759
+
760
+ # 5. shift matrix b and matrix d
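+ # (Transformer-XL style relative shift: the pad / reshape / slice below realigns the
+ # (time1, 2*time1-1) relative scores so each query position indexes the scores for its own relative offsets)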
761
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
762
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
763
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
764
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
765
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
766
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
767
+
768
+ # 6. sum matrices
769
+ # => (batch, head, time1, time2)
770
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
771
+
772
+ return scores
773
+
774
+
775
+ class Wav2Vec2ConformerEncoderLayer(nn.Module):
776
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
777
+
778
+ def __init__(self, config):
779
+ super().__init__()
780
+ embed_dim = config.hidden_size
781
+ dropout = config.attention_dropout
782
+
783
+ # Feed-forward 1
784
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
785
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
786
+
787
+ # Self-Attention
788
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
789
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
790
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
791
+
792
+ # Conformer Convolution
793
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
794
+
795
+ # Feed-forward 2
796
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
797
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
798
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
799
+
800
+ def forward(
801
+ self,
802
+ hidden_states,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ relative_position_embeddings: Optional[torch.Tensor] = None,
805
+ output_attentions: bool = False,
806
+ ):
807
+ hidden_states = hidden_states
808
+
809
+ # 1. Feed-Forward 1 layer
810
+ residual = hidden_states
811
+ hidden_states = self.ffn1_layer_norm(hidden_states)
812
+ hidden_states = self.ffn1(hidden_states)
813
+ hidden_states = hidden_states * 0.5 + residual
814
+ residual = hidden_states
815
+
816
+ # 2. Self-Attention layer
817
+ hidden_states = self.self_attn_layer_norm(hidden_states)
818
+ hidden_states, attn_weights = self.self_attn(
819
+ hidden_states=hidden_states,
820
+ attention_mask=attention_mask,
821
+ relative_position_embeddings=relative_position_embeddings,
822
+ output_attentions=output_attentions,
823
+ )
824
+ hidden_states = self.self_attn_dropout(hidden_states)
825
+ hidden_states = hidden_states + residual
826
+
827
+ # 3. Convolutional Layer
828
+ residual = hidden_states
829
+ hidden_states = self.conv_module(hidden_states)
830
+ hidden_states = residual + hidden_states
831
+
832
+ # 4. Feed-Forward 2 Layer
833
+ residual = hidden_states
834
+ hidden_states = self.ffn2_layer_norm(hidden_states)
835
+ hidden_states = self.ffn2(hidden_states)
836
+ hidden_states = hidden_states * 0.5 + residual
837
+ hidden_states = self.final_layer_norm(hidden_states)
838
+
839
+ return hidden_states, attn_weights
840
+
841
+
842
+ class Wav2Vec2ConformerEncoder(nn.Module):
843
+ def __init__(self, config, is_causal=False):
844
+ super().__init__()
845
+ config.is_causal = is_causal
846
+ self.config = config
847
+
848
+ if config.position_embeddings_type == "relative":
849
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
850
+ elif config.position_embeddings_type == "rotary":
851
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
852
+ else:
853
+ self.embed_positions = None
854
+
855
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
856
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
857
+ self.dropout = nn.Dropout(config.hidden_dropout)
858
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
859
+ self.gradient_checkpointing = False
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states,
864
+ attention_mask=None,
865
+ output_attentions=False,
866
+ output_hidden_states=False,
867
+ return_dict=True,
868
+ ):
869
+ all_hidden_states = () if output_hidden_states else None
870
+ all_self_attentions = () if output_attentions else None
871
+
872
+ if attention_mask is not None:
873
+ # make sure padded tokens output 0
874
+ hidden_states[~attention_mask] = 0.0
875
+
876
+ # extend attention_mask
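+ # the boolean mask becomes an additive float mask: masked positions receive the dtype minimum
+ # so they contribute (numerically) nothing after the attention softmax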
877
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
878
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
879
+ attention_mask = attention_mask.expand(
880
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
881
+ )
882
+
883
+ hidden_states = self.dropout(hidden_states)
884
+
885
+ if self.embed_positions is not None:
886
+ relative_position_embeddings = self.embed_positions(hidden_states)
887
+ else:
888
+ relative_position_embeddings = None
889
+
890
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
891
+
892
+ for i, layer in enumerate(self.layers):
893
+ if output_hidden_states:
894
+ all_hidden_states = all_hidden_states + (hidden_states,)
895
+
896
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
897
+ dropout_probability = np.random.uniform(0, 1)
898
+
899
+ skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
900
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
901
+ # under deepspeed zero3 all gpus must run in sync
902
+ if self.gradient_checkpointing and self.training:
903
+ # create gradient checkpointing function
904
+ def create_custom_forward(module):
905
+ def custom_forward(*inputs):
906
+ return module(*inputs, output_attentions)
907
+
908
+ return custom_forward
909
+
910
+ layer_outputs = torch.utils.checkpoint.checkpoint(
911
+ create_custom_forward(layer),
912
+ hidden_states,
913
+ attention_mask,
914
+ relative_position_embeddings,
915
+ )
916
+ else:
917
+ layer_outputs = layer(
918
+ hidden_states,
919
+ attention_mask=attention_mask,
920
+ relative_position_embeddings=relative_position_embeddings,
921
+ output_attentions=output_attentions,
922
+ )
923
+ hidden_states = layer_outputs[0]
924
+
925
+ if skip_the_layer:
926
+ layer_outputs = (None, None)
927
+
928
+ if output_attentions:
929
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
930
+
931
+ hidden_states = self.layer_norm(hidden_states)
932
+ if output_hidden_states:
933
+ all_hidden_states = all_hidden_states + (hidden_states,)
934
+
935
+ if not return_dict:
936
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
937
+ return BaseModelOutput(
938
+ last_hidden_state=hidden_states,
939
+ hidden_states=all_hidden_states,
940
+ attentions=all_self_attentions,
941
+ )
942
+
943
+
944
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
945
+ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
946
+ """
947
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
948
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf)` for more information.
949
+ """
950
+
951
+ def __init__(self, config):
952
+ super().__init__()
953
+ self.num_groups = config.num_codevector_groups
954
+ self.num_vars = config.num_codevectors_per_group
955
+
956
+ if config.codevector_dim % self.num_groups != 0:
957
+ raise ValueError(
958
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
959
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
960
+ )
961
+
962
+ # storage for codebook variables (codewords)
963
+ self.codevectors = nn.Parameter(
964
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
965
+ )
966
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
967
+
968
+ # can be decayed for training
969
+ self.temperature = 2
970
+
971
+ @staticmethod
972
+ def _compute_perplexity(probs, mask=None):
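+ # perplexity of the marginal codevector distribution; higher values mean the codebook is used
+ # more uniformly (this quantity feeds the diversity loss during pre-training)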
973
+ if mask is not None:
974
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
975
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
976
+ marginal_probs = probs.sum(dim=0) / mask.sum()
977
+ else:
978
+ marginal_probs = probs.mean(dim=0)
979
+
980
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
981
+ return perplexity
982
+
983
+ def forward(self, hidden_states, mask_time_indices=None):
984
+ batch_size, sequence_length, hidden_size = hidden_states.shape
985
+
986
+ # project to codevector dim
987
+ hidden_states = self.weight_proj(hidden_states)
988
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
989
+
990
+ if self.training:
991
+ # sample code vector probs via gumbel in differentiable way
992
+ codevector_probs = nn.functional.gumbel_softmax(
993
+ hidden_states.float(), tau=self.temperature, hard=True
994
+ ).type_as(hidden_states)
995
+
996
+ # compute perplexity
997
+ codevector_soft_dist = torch.softmax(
998
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
999
+ )
1000
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
1001
+ else:
1002
+ # take argmax in non-differentiable way
1003
+ # compute hard codevector distribution (one hot)
1004
+ codevector_idx = hidden_states.argmax(dim=-1)
1005
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
1006
+ -1, codevector_idx.view(-1, 1), 1.0
1007
+ )
1008
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
1009
+
1010
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
1011
+
1012
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
1013
+ # use probs to retrieve codevectors
1014
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
1015
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
1016
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
1017
+
1018
+ return codevectors, perplexity
1019
+
1020
+
1021
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
1022
+ class Wav2Vec2ConformerAdapter(nn.Module):
1023
+ def __init__(self, config):
1024
+ super().__init__()
1025
+
1026
+ # feature dim might need to be down-projected
1027
+ if config.output_hidden_size != config.hidden_size:
1028
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
1029
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
1030
+ else:
1031
+ self.proj = self.proj_layer_norm = None
1032
+
1033
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
1034
+ self.layerdrop = config.layerdrop
1035
+
1036
+ def forward(self, hidden_states):
1037
+ # down project hidden_states if necessary
1038
+ if self.proj is not None and self.proj_layer_norm is not None:
1039
+ hidden_states = self.proj(hidden_states)
1040
+ hidden_states = self.proj_layer_norm(hidden_states)
1041
+
1042
+ hidden_states = hidden_states.transpose(1, 2)
1043
+
1044
+ for layer in self.layers:
1045
+ layerdrop_prob = np.random.random()
1046
+ if not self.training or (layerdrop_prob > self.layerdrop):
1047
+ hidden_states = layer(hidden_states)
1048
+
1049
+ hidden_states = hidden_states.transpose(1, 2)
1050
+ return hidden_states
1051
+
1052
+
1053
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
1054
+ class Wav2Vec2ConformerAdapterLayer(nn.Module):
1055
+ def __init__(self, config):
1056
+ super().__init__()
1057
+ self.conv = nn.Conv1d(
1058
+ config.output_hidden_size,
1059
+ 2 * config.output_hidden_size,
1060
+ config.adapter_kernel_size,
1061
+ stride=config.adapter_stride,
1062
+ padding=1,
1063
+ )
1064
+
1065
+ def forward(self, hidden_states):
1066
+ hidden_states = self.conv(hidden_states)
1067
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
1068
+
1069
+ return hidden_states
1070
+
1071
+
1072
+ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
1073
+ """
1074
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1075
+ models.
1076
+ """
1077
+
1078
+ config_class = Wav2Vec2ConformerConfig
1079
+ base_model_prefix = "wav2vec2_conformer"
1080
+ main_input_name = "input_values"
1081
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1082
+ supports_gradient_checkpointing = True
1083
+
1084
+ def _init_weights(self, module):
1085
+ """Initialize the weights"""
1086
+ # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
1087
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
1088
+ module.project_hid.reset_parameters()
1089
+ module.project_q.reset_parameters()
1090
+ module.project_hid._is_hf_initialized = True
1091
+ module.project_q._is_hf_initialized = True
1092
+ # gumbel softmax requires special init
1093
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
1094
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
1095
+ module.weight_proj.bias.data.zero_()
1096
+ nn.init.uniform_(module.codevectors)
1097
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
1098
+ if hasattr(module, "pos_bias_u"):
1099
+ nn.init.xavier_uniform_(module.pos_bias_u)
1100
+ if hasattr(module, "pos_bias_v"):
1101
+ nn.init.xavier_uniform_(module.pos_bias_v)
1102
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
1103
+ nn.init.normal_(
1104
+ module.conv.weight,
1105
+ mean=0,
1106
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
1107
+ )
1108
+ nn.init.constant_(module.conv.bias, 0)
1109
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
1110
+ k = math.sqrt(1 / module.projection.in_features)
1111
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
1112
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
1113
+ elif isinstance(module, nn.Linear):
1114
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1115
+
1116
+ if module.bias is not None:
1117
+ module.bias.data.zero_()
1118
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1119
+ module.bias.data.zero_()
1120
+ module.weight.data.fill_(1.0)
1121
+ elif isinstance(module, nn.Conv1d):
1122
+ nn.init.kaiming_normal_(module.weight)
1123
+
1124
+ if module.bias is not None:
1125
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1126
+ nn.init.uniform_(module.bias, a=-k, b=k)
1127
+
1128
+ def _get_feat_extract_output_lengths(
1129
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
1130
+ ):
1131
+ """
1132
+ Computes the output length of the convolutional layers
1133
+ """
1134
+
1135
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
1136
+
1137
+ def _conv_out_length(input_length, kernel_size, stride):
1138
+ # 1D convolutional layer output length formula taken
1139
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1140
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1141
+
1142
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1143
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1144
+
1145
+ if add_adapter:
1146
+ for _ in range(self.config.num_adapter_layers):
1147
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
1148
+
1149
+ return input_lengths
1150
+
1151
+ def _get_feature_vector_attention_mask(
1152
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
1153
+ ):
1154
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
1155
+ # on inference mode.
1156
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
1157
+
1158
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
1159
+ output_lengths = output_lengths.to(torch.long)
1160
+
1161
+ batch_size = attention_mask.shape[0]
1162
+
1163
+ attention_mask = torch.zeros(
1164
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1165
+ )
1166
+ # these two operations make sure that all values before the output length indices are attended to
1167
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
1168
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
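+ # (the flip / cumsum / flip spreads the single 1 placed at the last valid index back over all earlier positions)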
1169
+ return attention_mask
1170
+
1171
+ def _set_gradient_checkpointing(self, module, value=False):
1172
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
1173
+ module.gradient_checkpointing = value
1174
+
1175
+
1176
+ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
1177
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
1178
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
1179
+ Auli.
1180
+
1181
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1182
+ library implements for all its models (such as downloading or saving, etc.).
1183
+
1184
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
1185
+ regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
1186
+
1187
+ Parameters:
1188
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
1189
+ Initializing with a config file does not load the weights associated with the model, only the
1190
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1191
+ """
1192
+
1193
+
1194
+ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
1195
+ Args:
1196
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1197
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1198
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1199
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1200
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1201
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1202
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1203
+ 1]`:
1204
+
1205
+ - 1 for tokens that are **not masked**,
1206
+ - 0 for tokens that are **masked**.
1207
+
1208
+ [What are attention masks?](../glossary#attention-mask)
1209
+
1210
+ <Tip warning={true}>
1211
+
1212
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
1213
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
1214
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
1215
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
1216
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
1217
+ that these models also yield slightly different results depending on whether `input_values` is padded or
1218
+ not.
1219
+
1220
+ </Tip>
1221
+
1222
+ output_attentions (`bool`, *optional*):
1223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
+ tensors for more detail.
1225
+ output_hidden_states (`bool`, *optional*):
1226
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
+ more detail.
1228
+ return_dict (`bool`, *optional*):
1229
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
+ """
1231
+
1232
+
1233
+ @add_start_docstrings(
1234
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
1235
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1236
+ )
1237
+ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
1238
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1239
+ super().__init__(config)
1240
+ self.config = config
1241
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
1242
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
1243
+
1244
+ # model only needs masking vector if mask prob is > 0.0
1245
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1246
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
1247
+
1248
+ self.encoder = Wav2Vec2ConformerEncoder(config)
1249
+
1250
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
1256
+ def freeze_feature_encoder(self):
1257
+ """
1258
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1259
+ not be updated during training.
1260
+ """
1261
+ self.feature_extractor._freeze_parameters()
1262
+
1263
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1264
+ def _mask_hidden_states(
1265
+ self,
1266
+ hidden_states: torch.FloatTensor,
1267
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1268
+ attention_mask: Optional[torch.LongTensor] = None,
1269
+ ):
1270
+ """
1271
+ Masks extracted features along time axis and/or along feature axis according to
1272
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
1273
+ """
1274
+
1275
+ # `config.apply_spec_augment` can set masking to False
1276
+ if not getattr(self.config, "apply_spec_augment", True):
1277
+ return hidden_states
1278
+
1279
+ # generate indices & apply SpecAugment along time axis
1280
+ batch_size, sequence_length, hidden_size = hidden_states.size()
1281
+
1282
+ if mask_time_indices is not None:
1283
+ # apply SpecAugment along time axis with given mask_time_indices
1284
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1285
+ elif self.config.mask_time_prob > 0 and self.training:
1286
+ mask_time_indices = _compute_mask_indices(
1287
+ (batch_size, sequence_length),
1288
+ mask_prob=self.config.mask_time_prob,
1289
+ mask_length=self.config.mask_time_length,
1290
+ attention_mask=attention_mask,
1291
+ min_masks=self.config.mask_time_min_masks,
1292
+ )
1293
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1294
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1295
+
1296
+ if self.config.mask_feature_prob > 0 and self.training:
1297
+ # generate indices & apply SpecAugment along feature axis
1298
+ mask_feature_indices = _compute_mask_indices(
1299
+ (batch_size, hidden_size),
1300
+ mask_prob=self.config.mask_feature_prob,
1301
+ mask_length=self.config.mask_feature_length,
1302
+ min_masks=self.config.mask_feature_min_masks,
1303
+ )
1304
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1305
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1306
+ hidden_states[mask_feature_indices] = 0
1307
+
1308
+ return hidden_states
1309
+
1310
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1311
+ @add_code_sample_docstrings(
1312
+ checkpoint=_CHECKPOINT_FOR_DOC,
1313
+ output_type=Wav2Vec2BaseModelOutput,
1314
+ config_class=_CONFIG_FOR_DOC,
1315
+ modality="audio",
1316
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1317
+ )
1318
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
1319
+ def forward(
1320
+ self,
1321
+ input_values: Optional[torch.Tensor],
1322
+ attention_mask: Optional[torch.Tensor] = None,
1323
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
1328
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1329
+ output_hidden_states = (
1330
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1331
+ )
1332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
+
1334
+ extract_features = self.feature_extractor(input_values)
1335
+ extract_features = extract_features.transpose(1, 2)
1336
+
1337
+ if attention_mask is not None:
1338
+ # compute reduced attention_mask corresponding to feature vectors
1339
+ attention_mask = self._get_feature_vector_attention_mask(
1340
+ extract_features.shape[1], attention_mask, add_adapter=False
1341
+ )
1342
+
1343
+ hidden_states, extract_features = self.feature_projection(extract_features)
1344
+ hidden_states = self._mask_hidden_states(
1345
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
1346
+ )
1347
+
1348
+ encoder_outputs = self.encoder(
1349
+ hidden_states,
1350
+ attention_mask=attention_mask,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ )
1355
+
1356
+ hidden_states = encoder_outputs[0]
1357
+
1358
+ if self.adapter is not None:
1359
+ hidden_states = self.adapter(hidden_states)
1360
+
1361
+ if not return_dict:
1362
+ return (hidden_states, extract_features) + encoder_outputs[1:]
1363
+
1364
+ return Wav2Vec2BaseModelOutput(
1365
+ last_hidden_state=hidden_states,
1366
+ extract_features=extract_features,
1367
+ hidden_states=encoder_outputs.hidden_states,
1368
+ attentions=encoder_outputs.attentions,
1369
+ )
1370
+
1371
+
1372
+ @add_start_docstrings(
1373
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
1374
+ )
1375
+ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
1376
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1377
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1378
+ super().__init__(config)
1379
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1380
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
1381
+
1382
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
1383
+
1384
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
1385
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
1386
+
1387
+ # Initialize weights and apply final processing
1388
+ self.post_init()
1389
+
1390
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
1391
+ def set_gumbel_temperature(self, temperature: int):
1392
+ """
1393
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
1394
+ """
1395
+ self.quantizer.temperature = temperature
1396
+
1397
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1398
+ def freeze_feature_encoder(self):
1399
+ """
1400
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1401
+ not be updated during training.
1402
+ """
1403
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1404
+
1405
+ @staticmethod
1406
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
1407
+ def compute_contrastive_logits(
1408
+ target_features: torch.FloatTensor,
1409
+ negative_features: torch.FloatTensor,
1410
+ predicted_features: torch.FloatTensor,
1411
+ temperature: int = 0.1,
1412
+ ):
1413
+ """
1414
+ Compute logits for the contrastive loss, using cosine similarity as the distance measure between
1415
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
1416
+ """
1417
+ target_features = torch.cat([target_features, negative_features], dim=0)
1418
+
1419
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
1420
+ target_features
1421
+ )
1422
+
1423
+ # apply temperature
1424
+ logits = logits / temperature
1425
+ return logits
1426
+
1427
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1428
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1429
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
1430
+ def forward(
1431
+ self,
1432
+ input_values: Optional[torch.Tensor],
1433
+ attention_mask: Optional[torch.Tensor] = None,
1434
+ mask_time_indices: Optional[torch.BoolTensor] = None,
1435
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
1436
+ output_attentions: Optional[bool] = None,
1437
+ output_hidden_states: Optional[bool] = None,
1438
+ return_dict: Optional[bool] = None,
1439
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
1440
+ r"""
1441
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
1442
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
1443
+ masked extracted features in *config.proj_codevector_dim* space.
1444
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
1445
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
1446
+ Required input for pre-training.
1447
+
1448
+ Returns:
1449
+
1450
+ Example:
1451
+
1452
+ ```python
1453
+ >>> import torch
1454
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
1455
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
1456
+ ... _compute_mask_indices,
1457
+ ... _sample_negative_indices,
1458
+ ... )
1459
+ >>> from datasets import load_dataset
1460
+
1461
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1462
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1463
+
1464
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1465
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
1466
+
1467
+ >>> # compute masked indices
1468
+ >>> batch_size, raw_sequence_length = input_values.shape
1469
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
1470
+ >>> mask_time_indices = _compute_mask_indices(
1471
+ ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
1472
+ ... )
1473
+ >>> sampled_negative_indices = _sample_negative_indices(
1474
+ ... features_shape=(batch_size, sequence_length),
1475
+ ... num_negatives=model.config.num_negatives,
1476
+ ... mask_time_indices=mask_time_indices,
1477
+ ... )
1478
+ >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
1479
+ >>> sampled_negative_indices = torch.tensor(
1480
+ ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
1481
+ ... )
1482
+
1483
+ >>> with torch.no_grad():
1484
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
1485
+
1486
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
1487
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
1488
+
1489
+ >>> # show that cosine similarity is much higher than random
1490
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
1491
+ tensor(True)
1492
+
1493
+ >>> # for contrastive loss training model should be put into train mode
1494
+ >>> model = model.train()
1495
+ >>> loss = model(
1496
+ ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
1497
+ ... ).loss
1498
+ ```"""
1499
+
1500
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1501
+
1502
+ if mask_time_indices is not None:
1503
+ mask_time_indices = mask_time_indices.to(torch.bool)
1504
+
1505
+ outputs = self.wav2vec2_conformer(
1506
+ input_values,
1507
+ attention_mask=attention_mask,
1508
+ output_attentions=output_attentions,
1509
+ output_hidden_states=output_hidden_states,
1510
+ mask_time_indices=mask_time_indices,
1511
+ return_dict=return_dict,
1512
+ )
1513
+
1514
+ # 1. project all transformed features (including masked) to final vq dim
1515
+ transformer_features = self.project_hid(outputs[0])
1516
+
1517
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
1518
+ extract_features = self.dropout_features(outputs[1])
1519
+
1520
+ if attention_mask is not None:
1521
+ # compute reduced attention_mask corresponding to feature vectors
1522
+ attention_mask = self._get_feature_vector_attention_mask(
1523
+ extract_features.shape[1], attention_mask, add_adapter=False
1524
+ )
1525
+
1526
+ quantized_features, codevector_perplexity = self.quantizer(
1527
+ extract_features, mask_time_indices=mask_time_indices
1528
+ )
1529
+ quantized_features = self.project_q(quantized_features)
1530
+
1531
+ loss = contrastive_loss = diversity_loss = None
1532
+ if sampled_negative_indices is not None:
1533
+ batch_size, sequence_length, hidden_size = quantized_features.shape
1534
+
1535
+ # for training, we sample negatives
1536
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
1537
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
1538
+ # sample negative quantized vectors BTC => (BxT)C
1539
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
1540
+ sampled_negative_indices.long().view(-1)
1541
+ ]
1542
+ negative_quantized_features = negative_quantized_features.view(
1543
+ batch_size, sequence_length, -1, hidden_size
1544
+ ).permute(2, 0, 1, 3)
1545
+
1546
+ # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
1547
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
1548
+ logits = self.compute_contrastive_logits(
1549
+ quantized_features[None, :],
1550
+ negative_quantized_features,
1551
+ transformer_features,
1552
+ self.config.contrastive_logits_temperature,
1553
+ )
1554
+
1555
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
1556
+ # its cosine similarity will be masked
1557
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
1558
+
1559
+ if neg_is_pos.any():
1560
+ logits[1:][neg_is_pos] = float("-inf")
1561
+
1562
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
1563
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
1564
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
1565
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
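+ # masked (predicted) positions get target 0, i.e. the positive candidate at index 0 of the logits;
+ # unmasked positions are set to -100 and ignored by cross_entropy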
1566
+
1567
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
1568
+ # 7. compute diversity loss: \mathbf{L}_d
1569
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
1570
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
1571
+
1572
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
1573
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
1574
+
1575
+ if not return_dict:
1576
+ if loss is not None:
1577
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1578
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1579
+
1580
+ return Wav2Vec2ConformerForPreTrainingOutput(
1581
+ loss=loss,
1582
+ projected_states=transformer_features,
1583
+ projected_quantized_states=quantized_features,
1584
+ codevector_perplexity=codevector_perplexity,
1585
+ hidden_states=outputs.hidden_states,
1586
+ attentions=outputs.attentions,
1587
+ contrastive_loss=contrastive_loss,
1588
+ diversity_loss=diversity_loss,
1589
+ )
1590
+
1591
+
1592
+ @add_start_docstrings(
1593
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1594
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1595
+ )
1596
+ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
1597
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1598
+ def __init__(self, config):
1599
+ super().__init__(config)
1600
+
1601
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1602
+ self.dropout = nn.Dropout(config.final_dropout)
1603
+
1604
+ if config.vocab_size is None:
1605
+ raise ValueError(
1606
+ f"You are trying to instantiate {self.__class__} with a configuration that "
1607
+ "does not define the vocabulary size of the language model head. Please "
1608
+ "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1609
+ "or define `vocab_size` of your model's configuration."
1610
+ )
1611
+ output_hidden_size = (
1612
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1613
+ )
1614
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1615
+
1616
+ # Initialize weights and apply final processing
1617
+ self.post_init()
1618
+
1619
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1620
+ def freeze_feature_encoder(self):
1621
+ """
1622
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1623
+ not be updated during training.
1624
+ """
1625
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1626
+
1627
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1628
+ @add_code_sample_docstrings(
1629
+ checkpoint=_CHECKPOINT_FOR_DOC,
1630
+ output_type=CausalLMOutput,
1631
+ config_class=_CONFIG_FOR_DOC,
1632
+ expected_output=_CTC_EXPECTED_OUTPUT,
1633
+ expected_loss=_CTC_EXPECTED_LOSS,
1634
+ )
1635
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1636
+ def forward(
1637
+ self,
1638
+ input_values: Optional[torch.Tensor],
1639
+ attention_mask: Optional[torch.Tensor] = None,
1640
+ output_attentions: Optional[bool] = None,
1641
+ output_hidden_states: Optional[bool] = None,
1642
+ return_dict: Optional[bool] = None,
1643
+ labels: Optional[torch.Tensor] = None,
1644
+ ) -> Union[Tuple, CausalLMOutput]:
1645
+ r"""
1646
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1647
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
1648
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1649
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1650
+ config.vocab_size - 1]`.
1651
+ """
1652
+
1653
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1654
+
1655
+ outputs = self.wav2vec2_conformer(
1656
+ input_values,
1657
+ attention_mask=attention_mask,
1658
+ output_attentions=output_attentions,
1659
+ output_hidden_states=output_hidden_states,
1660
+ return_dict=return_dict,
1661
+ )
1662
+
1663
+ hidden_states = outputs[0]
1664
+ hidden_states = self.dropout(hidden_states)
1665
+
1666
+ logits = self.lm_head(hidden_states)
1667
+
1668
+ loss = None
1669
+ if labels is not None:
1670
+ if labels.max() >= self.config.vocab_size:
1671
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1672
+
1673
+ # retrieve loss input_lengths from attention_mask
1674
+ attention_mask = (
1675
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1676
+ )
1677
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1678
+
1679
+ # assuming that padded tokens are filled with -100
1680
+ # when not being attended to
1681
+ labels_mask = labels >= 0
1682
+ target_lengths = labels_mask.sum(-1)
1683
+ flattened_targets = labels.masked_select(labels_mask)
1684
+
1685
+ # ctc_loss doesn't support fp16
1686
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1687
+
1688
+ with torch.backends.cudnn.flags(enabled=False):
1689
+ loss = nn.functional.ctc_loss(
1690
+ log_probs,
1691
+ flattened_targets,
1692
+ input_lengths,
1693
+ target_lengths,
1694
+ blank=self.config.pad_token_id,
1695
+ reduction=self.config.ctc_loss_reduction,
1696
+ zero_infinity=self.config.ctc_zero_infinity,
1697
+ )
1698
+
1699
+ if not return_dict:
1700
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1701
+ return ((loss,) + output) if loss is not None else output
1702
+
1703
+ return CausalLMOutput(
1704
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1705
+ )
1706
+
1707
+
1708
+ @add_start_docstrings(
1709
+ """
1710
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
1711
+ tasks like SUPERB Keyword Spotting.
1712
+ """,
1713
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1714
+ )
1715
+ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
1716
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1717
+ def __init__(self, config):
1718
+ super().__init__(config)
1719
+
1720
+ if hasattr(config, "add_adapter") and config.add_adapter:
1721
+ raise ValueError(
1722
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1723
+ )
1724
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1725
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1726
+ if config.use_weighted_layer_sum:
1727
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1728
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1729
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1730
+
1731
+ # Initialize weights and apply final processing
1732
+ self.post_init()
1733
+
1734
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1735
+ def freeze_feature_encoder(self):
1736
+ """
1737
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1738
+ not be updated during training.
1739
+ """
1740
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1741
+
1742
+ def freeze_base_model(self):
1743
+ """
1744
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1745
+ be updated during training. Only the classification head will be updated.
1746
+ """
1747
+ for param in self.wav2vec2_conformer.parameters():
1748
+ param.requires_grad = False
1749
+
1750
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1751
+ @add_code_sample_docstrings(
1752
+ checkpoint=_CHECKPOINT_FOR_DOC,
1753
+ output_type=SequenceClassifierOutput,
1754
+ config_class=_CONFIG_FOR_DOC,
1755
+ modality="audio",
1756
+ )
1757
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1758
+ def forward(
1759
+ self,
1760
+ input_values: Optional[torch.Tensor],
1761
+ attention_mask: Optional[torch.Tensor] = None,
1762
+ output_attentions: Optional[bool] = None,
1763
+ output_hidden_states: Optional[bool] = None,
1764
+ return_dict: Optional[bool] = None,
1765
+ labels: Optional[torch.Tensor] = None,
1766
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1767
+ r"""
1768
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1769
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1770
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1771
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1772
+ """
1773
+
1774
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1775
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1776
+
1777
+ outputs = self.wav2vec2_conformer(
1778
+ input_values,
1779
+ attention_mask=attention_mask,
1780
+ output_attentions=output_attentions,
1781
+ output_hidden_states=output_hidden_states,
1782
+ return_dict=return_dict,
1783
+ )
1784
+
1785
+ if self.config.use_weighted_layer_sum:
1786
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1787
+ hidden_states = torch.stack(hidden_states, dim=1)
1788
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1789
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1790
+ else:
1791
+ hidden_states = outputs[0]
1792
+
1793
+ hidden_states = self.projector(hidden_states)
1794
+ if attention_mask is None:
1795
+ pooled_output = hidden_states.mean(dim=1)
1796
+ else:
1797
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1798
+ hidden_states[~padding_mask] = 0.0
1799
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1800
+
1801
+ logits = self.classifier(pooled_output)
1802
+
1803
+ loss = None
1804
+ if labels is not None:
1805
+ loss_fct = CrossEntropyLoss()
1806
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1807
+
1808
+ if not return_dict:
1809
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1810
+ return ((loss,) + output) if loss is not None else output
1811
+
1812
+ return SequenceClassifierOutput(
1813
+ loss=loss,
1814
+ logits=logits,
1815
+ hidden_states=outputs.hidden_states,
1816
+ attentions=outputs.attentions,
1817
+ )
1818
+
1819
+
1820
+ @add_start_docstrings(
1821
+ """
1822
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
1823
+ """,
1824
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1825
+ )
1826
+ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
1827
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1828
+ def __init__(self, config):
1829
+ super().__init__(config)
1830
+
1831
+ if hasattr(config, "add_adapter") and config.add_adapter:
1832
+ raise ValueError(
1833
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1834
+ )
1835
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1836
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1837
+ if config.use_weighted_layer_sum:
1838
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1839
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1840
+ self.num_labels = config.num_labels
1841
+
1842
+ self.init_weights()
1843
+
1844
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1845
+ def freeze_feature_encoder(self):
1846
+ """
1847
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1848
+ not be updated during training.
1849
+ """
1850
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1851
+
1852
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
1853
+ def freeze_base_model(self):
1854
+ """
1855
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1856
+ be updated during training. Only the classification head will be updated.
1857
+ """
1858
+ for param in self.wav2vec2_conformer.parameters():
1859
+ param.requires_grad = False
1860
+
1861
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1862
+ @add_code_sample_docstrings(
1863
+ checkpoint=_CHECKPOINT_FOR_DOC,
1864
+ output_type=TokenClassifierOutput,
1865
+ config_class=_CONFIG_FOR_DOC,
1866
+ modality="audio",
1867
+ )
1868
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
1869
+ def forward(
1870
+ self,
1871
+ input_values: Optional[torch.Tensor],
1872
+ attention_mask: Optional[torch.Tensor] = None,
1873
+ labels: Optional[torch.Tensor] = None,
1874
+ output_attentions: Optional[bool] = None,
1875
+ output_hidden_states: Optional[bool] = None,
1876
+ return_dict: Optional[bool] = None,
1877
+ ) -> Union[Tuple, TokenClassifierOutput]:
1878
+ r"""
1879
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1880
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1881
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1882
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1883
+ """
1884
+
1885
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1886
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1887
+
1888
+ outputs = self.wav2vec2_conformer(
1889
+ input_values,
1890
+ attention_mask=attention_mask,
1891
+ output_attentions=output_attentions,
1892
+ output_hidden_states=output_hidden_states,
1893
+ return_dict=return_dict,
1894
+ )
1895
+
1896
+ if self.config.use_weighted_layer_sum:
1897
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1898
+ hidden_states = torch.stack(hidden_states, dim=1)
1899
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1900
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1901
+ else:
1902
+ hidden_states = outputs[0]
1903
+
1904
+ logits = self.classifier(hidden_states)
1905
+
1906
+ loss = None
1907
+ if labels is not None:
1908
+ loss_fct = CrossEntropyLoss()
1909
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
1910
+
1911
+ if not return_dict:
1912
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1913
+ return output
1914
+
1915
+ return TokenClassifierOutput(
1916
+ loss=loss,
1917
+ logits=logits,
1918
+ hidden_states=outputs.hidden_states,
1919
+ attentions=outputs.attentions,
1920
+ )
1921
+
1922
+
1923
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
1924
+ class AMSoftmaxLoss(nn.Module):
1925
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
1926
+ super(AMSoftmaxLoss, self).__init__()
1927
+ self.scale = scale
1928
+ self.margin = margin
1929
+ self.num_labels = num_labels
1930
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
1931
+ self.loss = nn.CrossEntropyLoss()
1932
+
1933
+ def forward(self, hidden_states, labels):
1934
+ labels = labels.flatten()
1935
+ weight = nn.functional.normalize(self.weight, dim=0)
1936
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
1937
+ cos_theta = torch.mm(hidden_states, weight)
1938
+ psi = cos_theta - self.margin
1939
+
1940
+ onehot = nn.functional.one_hot(labels, self.num_labels)
1941
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
1942
+ loss = self.loss(logits, labels)
1943
+
1944
+ return loss
1945
+
1946
+
1947
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
1948
+ class TDNNLayer(nn.Module):
1949
+ def __init__(self, config, layer_id=0):
1950
+ super().__init__()
1951
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
1952
+ self.out_conv_dim = config.tdnn_dim[layer_id]
1953
+ self.kernel_size = config.tdnn_kernel[layer_id]
1954
+ self.dilation = config.tdnn_dilation[layer_id]
1955
+
1956
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
1957
+ self.activation = nn.ReLU()
1958
+
1959
+ def forward(self, hidden_states):
1960
+ hidden_states = hidden_states.unsqueeze(1)
1961
+ hidden_states = nn.functional.unfold(
1962
+ hidden_states,
1963
+ (self.kernel_size, self.in_conv_dim),
1964
+ stride=(1, self.in_conv_dim),
1965
+ dilation=(self.dilation, 1),
1966
+ )
1967
+ hidden_states = hidden_states.transpose(1, 2)
1968
+ hidden_states = self.kernel(hidden_states)
1969
+
1970
+ hidden_states = self.activation(hidden_states)
1971
+ return hidden_states
1972
+
1973
+
1974
+ @add_start_docstrings(
1975
+ """
1976
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
1977
+ """,
1978
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1979
+ )
1980
+ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
1981
+ def __init__(self, config):
1982
+ super().__init__(config)
1983
+
1984
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1985
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1986
+ if config.use_weighted_layer_sum:
1987
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1988
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
1989
+
1990
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
1991
+ self.tdnn = nn.ModuleList(tdnn_layers)
1992
+
1993
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
1994
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
1995
+
1996
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
1997
+
1998
+ self.init_weights()
1999
+
2000
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
2001
+ def freeze_feature_encoder(self):
2002
+ """
2003
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
2004
+ not be updated during training.
2005
+ """
2006
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
2007
+
2008
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
2009
+ def freeze_base_model(self):
2010
+ """
2011
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
2012
+ be updated during training. Only the classification head will be updated.
2013
+ """
2014
+ for param in self.wav2vec2_conformer.parameters():
2015
+ param.requires_grad = False
2016
+
2017
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
2018
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
2019
+ """
2020
+ Computes the output length of the TDNN layers
2021
+ """
2022
+
2023
+ def _conv_out_length(input_length, kernel_size, stride):
2024
+ # 1D convolutional layer output length formula taken
2025
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
2026
+ return (input_length - kernel_size) // stride + 1
2027
+
2028
+ for kernel_size in self.config.tdnn_kernel:
2029
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
2030
+
2031
+ return input_lengths
2032
+
2033
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
2034
+ @add_code_sample_docstrings(
2035
+ checkpoint=_CHECKPOINT_FOR_DOC,
2036
+ output_type=XVectorOutput,
2037
+ config_class=_CONFIG_FOR_DOC,
2038
+ modality="audio",
2039
+ )
2040
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
2041
+ def forward(
2042
+ self,
2043
+ input_values: Optional[torch.Tensor],
2044
+ attention_mask: Optional[torch.Tensor] = None,
2045
+ output_attentions: Optional[bool] = None,
2046
+ output_hidden_states: Optional[bool] = None,
2047
+ return_dict: Optional[bool] = None,
2048
+ labels: Optional[torch.Tensor] = None,
2049
+ ) -> Union[Tuple, XVectorOutput]:
2050
+ r"""
2051
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2052
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
2053
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
2054
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
2055
+ """
2056
+
2057
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2058
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
2059
+
2060
+ outputs = self.wav2vec2_conformer(
2061
+ input_values,
2062
+ attention_mask=attention_mask,
2063
+ output_attentions=output_attentions,
2064
+ output_hidden_states=output_hidden_states,
2065
+ return_dict=return_dict,
2066
+ )
2067
+
2068
+ if self.config.use_weighted_layer_sum:
2069
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
2070
+ hidden_states = torch.stack(hidden_states, dim=1)
2071
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
2072
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
2073
+ else:
2074
+ hidden_states = outputs[0]
2075
+
2076
+ hidden_states = self.projector(hidden_states)
2077
+
2078
+ for tdnn_layer in self.tdnn:
2079
+ hidden_states = tdnn_layer(hidden_states)
2080
+
2081
+ # Statistic Pooling
2082
+ if attention_mask is None:
2083
+ mean_features = hidden_states.mean(dim=1)
2084
+ std_features = hidden_states.std(dim=1)
2085
+ else:
2086
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
2087
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
2088
+ mean_features = []
2089
+ std_features = []
2090
+ for i, length in enumerate(tdnn_output_lengths):
2091
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
2092
+ std_features.append(hidden_states[i, :length].std(dim=0))
2093
+ mean_features = torch.stack(mean_features)
2094
+ std_features = torch.stack(std_features)
2095
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
2096
+
2097
+ output_embeddings = self.feature_extractor(statistic_pooling)
2098
+ logits = self.classifier(output_embeddings)
2099
+
2100
+ loss = None
2101
+ if labels is not None:
2102
+ loss = self.objective(logits, labels)
2103
+
2104
+ if not return_dict:
2105
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
2106
+ return ((loss,) + output) if loss is not None else output
2107
+
2108
+ return XVectorOutput(
2109
+ loss=loss,
2110
+ logits=logits,
2111
+ embeddings=output_embeddings,
2112
+ hidden_states=outputs.hidden_states,
2113
+ attentions=outputs.attentions,
2114
+ )
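For orientation, a minimal sketch of how the frame-classification head added above could be driven end to end. The checkpoint path is a placeholder, the class is assumed to be importable from this module, and a stock Wav2Vec2FeatureExtractor stands in for whatever preprocessing this repo actually uses.

# Hedged sketch: frame-level classification with the Wav2Vec2Conformer head above.
# "path/to/checkpoint" is a placeholder, not a published model.
import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(sampling_rate=16000, return_attention_mask=True)
model = Wav2Vec2ConformerForAudioFrameClassification.from_pretrained("path/to/checkpoint")
model.eval()

audio = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of dummy audio
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits    # (batch, frames, num_labels)
frame_ids = logits.argmax(dim=-1)      # one class id per ~20 ms frame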
musicfm/modules/random_quantizer.py ADDED
@@ -0,0 +1,83 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import torch
17
+ from torch import nn, einsum
18
+ from einops import rearrange
19
+
20
+
21
+ class RandomProjectionQuantizer(nn.Module):
22
+ """
23
+ Random projection and codebook lookup module
24
+
25
+ Some code is borrowed from:
26
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
27
+ Normalization here uses a pre-computed global mean & variance instead of the layer norm used in the original.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ input_dim,
33
+ codebook_dim,
34
+ codebook_size,
35
+ seed=142,
36
+ ):
37
+ super().__init__()
38
+
39
+ # random seed
40
+ torch.manual_seed(seed)
41
+
42
+ # randomly initialized projection
43
+ random_projection = torch.empty(input_dim, codebook_dim)
44
+ nn.init.xavier_normal_(random_projection)
45
+ self.register_buffer("random_projection", random_projection)
46
+
47
+ # randomly initialized codebook
48
+ codebook = torch.empty(codebook_size, codebook_dim)
49
+ nn.init.normal_(codebook)
50
+ self.register_buffer("codebook", codebook)
51
+
52
+ def codebook_lookup(self, x):
53
+ # reshape
54
+ b = x.shape[0]
55
+ x = rearrange(x, "b n e -> (b n) e")
56
+
57
+ # L2 normalization
58
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
59
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
60
+
61
+ # compute distances
62
+ distances = torch.cdist(normalized_codebook, normalized_x)
63
+
64
+ # get nearest
65
+ nearest_indices = torch.argmin(distances, dim=0)
66
+
67
+ # reshape
68
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
69
+
70
+ return xq
71
+
72
+ @torch.no_grad()
73
+ def forward(self, x):
74
+ # always eval
75
+ self.eval()
76
+
77
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
78
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
79
+
80
+ # codebook lookup
81
+ xq = self.codebook_lookup(x)
82
+
83
+ return xq
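A short, self-contained sketch of what the quantizer above produces: each frame's feature vector is projected through the fixed random matrix and assigned the index of its nearest codebook entry (cosine distance via L2 normalization). The dimensions below are illustrative, not the values the pretrained model was trained with.

# Hedged sketch: turning continuous features into discrete tokens.
import torch

quantizer = RandomProjectionQuantizer(input_dim=1024, codebook_dim=16, codebook_size=4096)

features = torch.randn(2, 250, 1024)   # (batch, frames, input_dim) -- dummy values
tokens = quantizer(features)           # (batch, frames), integer indices in [0, 4096)

print(tokens.shape)                    # torch.Size([2, 250])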
postprocessing/functional.py ADDED
@@ -0,0 +1,71 @@
1
+ # This file contains code adapted from the following sources:
2
+ # [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/functional.py
3
+
4
+ import numpy as np
5
+ import torch
6
+ from .helpers import (
7
+ local_maxima,
8
+ peak_picking,
9
+ # event_frames_to_time,
10
+ )
11
+ from dataset.label2id import LABEL_TO_ID, ID_TO_LABEL
12
+ from dataset.custom_types import MsaInfo
13
+
14
+
15
+ def event_frames_to_time(frame_rates, boundary: np.ndarray):
16
+ boundary = np.array(boundary)
17
+ boundary_times = boundary / frame_rates
18
+ return boundary_times
19
+
20
+
21
+ def postprocess_functional_structure(
22
+ logits,
23
+ config,
24
+ ):
25
+ # pdb.set_trace()
26
+ boundary_logits = logits["boundary_logits"]
27
+ function_logits = logits["function_logits"]
28
+
29
+ assert boundary_logits.shape[0] == 1 and function_logits.shape[0] == 1, (
30
+ "Only batch size 1 is supported"
31
+ )
32
+ raw_prob_sections = torch.sigmoid(boundary_logits[0])
33
+ raw_prob_functions = torch.softmax(function_logits[0].transpose(0, 1), dim=0)
34
+
35
+ # filter_size=4 * cfg.min_hops_per_beat + 1
36
+ prob_sections, _ = local_maxima(
37
+ raw_prob_sections, filter_size=config.local_maxima_filter_size
38
+ )
39
+ prob_sections = prob_sections.cpu().numpy()
40
+
41
+ prob_functions = raw_prob_functions.cpu().numpy()
42
+
43
+ boundary_candidates = peak_picking(
44
+ boundary_activation=prob_sections,
45
+ window_past=int(12 * config.frame_rates),  # was originally fps
46
+ window_future=int(12 * config.frame_rates),
47
+ )
48
+ boundary = boundary_candidates > 0.0
49
+
50
+ duration = len(prob_sections) / config.frame_rates
51
+ pred_boundary_times = event_frames_to_time(
52
+ frame_rates=config.frame_rates, boundary=np.flatnonzero(boundary)
53
+ )
54
+ if pred_boundary_times[0] != 0:
55
+ pred_boundary_times = np.insert(pred_boundary_times, 0, 0)
56
+ if pred_boundary_times[-1] != duration:
57
+ pred_boundary_times = np.append(pred_boundary_times, duration)
58
+ pred_boundaries = np.stack([pred_boundary_times[:-1], pred_boundary_times[1:]]).T
59
+
60
+ pred_boundary_indices = np.flatnonzero(boundary)
61
+ pred_boundary_indices = pred_boundary_indices[pred_boundary_indices > 0]
62
+ prob_segment_function = np.split(prob_functions, pred_boundary_indices, axis=1)
63
+ pred_labels = [p.mean(axis=1).argmax().item() for p in prob_segment_function]
64
+
65
+ segments: MsaInfo = []
66
+ for (start, end), label in zip(pred_boundaries, pred_labels):
67
+ segment = (float(start), str(ID_TO_LABEL[label]))
68
+ segments.append(segment)
69
+
70
+ segments.append((float(pred_boundary_times[-1]), "end"))
71
+ return segments
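The function above reduces per-frame boundary/function logits to an MsaInfo list of (start_time, label) tuples, closed by an "end" marker at the track duration. A hedged sketch of the call pattern, using dummy logits and a SimpleNamespace in place of the real config; only the two fields the function reads are supplied, and the frame rate value is an assumption.

# Hedged sketch: post-processing dummy model outputs into segments.
import torch
from types import SimpleNamespace

num_frames, num_classes = 2500, 128
logits = {
    "boundary_logits": torch.randn(1, num_frames),              # batch size must be 1
    "function_logits": torch.randn(1, num_frames, num_classes),
}
config = SimpleNamespace(local_maxima_filter_size=3, frame_rates=8.333)

msa = postprocess_functional_structure(logits, config)
# e.g. [(0.0, "intro"), (14.4, "verse"), ..., (300.0, "end")] -- labels are arbitrary here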
postprocessing/helpers.py ADDED
@@ -0,0 +1,101 @@
1
+ # This file contains code adapted from the following sources:
2
+ # [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/helpers.py
3
+
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import librosa
8
+ from typing import Union
9
+ from scipy.signal import argrelextrema
10
+ from scipy.interpolate import interp1d
11
+ from numpy.lib.stride_tricks import sliding_window_view
12
+ from numpy.typing import NDArray
13
+
14
+
15
+ def local_maxima(tensor, filter_size=41):
16
+ assert len(tensor.shape) in (1, 2), "Input tensor should have 1 or 2 dimensions"
17
+ assert filter_size % 2 == 1, "Filter size should be an odd number"
18
+
19
+ original_shape = tensor.shape
20
+ if len(original_shape) == 1:
21
+ tensor = tensor.unsqueeze(0)
22
+
23
+ # Pad the input array with the minimum value
24
+ padding = filter_size // 2
25
+ padded_arr = F.pad(tensor, (padding, padding), mode="constant", value=-torch.inf)
26
+
27
+ # Create a rolling window view of the padded array
28
+ rolling_view = padded_arr.unfold(1, filter_size, 1)
29
+
30
+ # Find the indices of the local maxima
31
+ center = filter_size // 2
32
+ local_maxima_mask = torch.eq(
33
+ rolling_view[:, :, center], torch.max(rolling_view, dim=-1).values
34
+ )
35
+ local_maxima_indices = local_maxima_mask.nonzero()
36
+
37
+ # Initialize a new PyTorch tensor with zeros and the same shape as the input tensor
38
+ output_arr = torch.zeros_like(tensor)
39
+
40
+ # Set the local maxima values in the output tensor
41
+ output_arr[local_maxima_mask] = tensor[local_maxima_mask]
42
+
43
+ output_arr = output_arr.reshape(original_shape)
44
+
45
+ return output_arr, local_maxima_indices
46
+
47
+
48
+ def local_maxima_numpy(arr, order=20):
49
+ is_batch = len(arr.shape) == 2
50
+ if is_batch:
51
+ return np.stack([local_maxima_numpy(x, order) for x in arr])
52
+
53
+ # Define a comparison function for argrelextrema to find local maxima
54
+ compare_func = np.greater
55
+
56
+ # Find the indices of the local maxima
57
+ local_maxima_indices = argrelextrema(arr, compare_func, order=order)
58
+
59
+ # Initialize a new numpy array with zeros and the same shape as the input array
60
+ output_arr = np.zeros_like(arr)
61
+
62
+ # Set the local maxima values in the output array
63
+ output_arr[local_maxima_indices] = arr[local_maxima_indices]
64
+
65
+ return output_arr
66
+
67
+
68
+ def peak_picking(boundary_activation, window_past=12, window_future=6):
69
+ # Find local maxima using a sliding window
70
+ window_size = window_past + window_future
71
+ assert window_size % 2 == 0, "window_past + window_future must be even"
72
+ window_size += 1
73
+
74
+ # Pad boundary_activation
75
+ boundary_activation_padded = np.pad(
76
+ boundary_activation, (window_past, window_future), mode="constant"
77
+ )
78
+ max_filter = sliding_window_view(boundary_activation_padded, window_size)
79
+ local_maxima = (boundary_activation == np.max(max_filter, axis=-1)) & (
80
+ boundary_activation > 0
81
+ )
82
+
83
+ # Compute strength values by subtracting the mean of the past and future windows
84
+ past_window_filter = sliding_window_view(
85
+ boundary_activation_padded[: -(window_future + 1)], window_past
86
+ )
87
+ future_window_filter = sliding_window_view(
88
+ boundary_activation_padded[window_past + 1 :], window_future
89
+ )
90
+ past_mean = np.mean(past_window_filter, axis=-1)
91
+ future_mean = np.mean(future_window_filter, axis=-1)
92
+ strength_values = boundary_activation - ((past_mean + future_mean) / 2)
93
+
94
+ # Get boundary candidates and their corresponding strength values
95
+ boundary_candidates = np.flatnonzero(local_maxima)
96
+ strength_values = strength_values[boundary_candidates]
97
+
98
+ strength_activations = np.zeros_like(boundary_activation)
99
+ strength_activations[boundary_candidates] = strength_values
100
+
101
+ return strength_activations
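To make the peak picking concrete, a small sketch on a synthetic activation curve, assuming the helpers above are importable. The window sizes are arbitrary (their sum just has to be even, as the assertion requires), and the threshold is chosen by eye for this toy signal.

# Hedged sketch: peak_picking on a toy boundary-activation curve.
import numpy as np

rng = np.random.default_rng(0)
activation = 0.05 * rng.random(100)       # low-level noise
activation[[20, 50, 80]] = 1.0            # three clear "boundaries"

strengths = peak_picking(activation, window_past=12, window_future=6)
picked = np.flatnonzero(strengths > 0.5)  # keep only strong peaks
print(picked)                             # expected: [20 50 80]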