akanatas committed
Commit 885d6e7 · 1 Parent(s): 77b7ebb

Add modeling_MERT.py

Files changed (1):
  1. modeling_MERT.py +408 -0
modeling_MERT.py ADDED
@@ -0,0 +1,408 @@
"""
MERT model definition.

Adapted from: https://github.com/yizhilll/MERT/blob/main/scripts/mert_hf/modeling_MERT.py
"""

from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import BaseModelOutput
from transformers.models.hubert.modeling_hubert import (
    HubertFeatureEncoder,
    HubertModel,
    HubertEncoderStableLayerNorm,
    HubertEncoder,
    HubertEncoderLayer,
    HubertPositionalConvEmbedding,
    HubertAttention,
    HubertFeedForward,
)

try:
    from nnAudio import features as nnAudioFeatures
    NNAUDIO_INSTALLED = True
except ImportError:
    print("WARNING: feature_extractor_cqt requires the library 'nnAudio'")
    NNAUDIO_INSTALLED = False

from .configuration_MERT import MERTConfig


class MERTFeatureProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.feat_proj_layer_norm = config.feat_proj_layer_norm
        self.feature_extractor_cqt = config.feature_extractor_cqt

        if self.feature_extractor_cqt:
            # v3: concatenate the CQT bins with the convolutional features
            self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
            print(f"feature dimension: {self.feature_dimension}")
        else:
            self.feature_dimension = config.conv_dim[-1]
        if self.feat_proj_layer_norm:
            self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
        self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        # non-projected hidden states are needed for quantization
        if self.feat_proj_layer_norm:
            hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

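# Note on the forward pipeline of MERTModel below (summary of the code, added for
# readability): raw waveform -> HubertFeatureEncoder (1-D conv stack) -> optional
# concatenation with CQT bins -> MERTFeatureProjection -> transformer encoder.
# The CQT branch is only active when `config.feature_extractor_cqt` is set.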
class MERTModel(HubertModel):
    # overwrite config class
    config_class = MERTConfig
    base_model_prefix = "mert_model"

    def __init__(
        self,
        config: MERTConfig,
    ) -> None:
        """
        Initialize with the grandparent method HubertPreTrainedModel.__init__()
        and modify what HubertModel.__init__() would have set up.
        """
        super(HubertModel, self).__init__(config)

        self.config = config

        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = MERTFeatureProjection(config)  # replace the feature projection to introduce the new features

        if self.config.feature_extractor_cqt:
            assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio'; try again after `pip install nnAudio`"
            print('initializing cqt extractor for MERT')
            self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(
                sr=self.config.sample_rate,
                hop_length=self.config.sample_rate // 50,
                fmin=32.7,
                fmax=None,
                n_bins=self.config.feature_extractor_cqt_bins,
                bins_per_octave=self.config.feature_extractor_cqt_bins // 7,
                filter_scale=1,
                norm=1,
                window='hann',
                center=True,
                pad_mode='constant',
                trainable=False,
                output_format='Magnitude',
                verbose=True,
            )

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            assert not config.deepnorm, "must use post-layer_norm with deepnorm"
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            if config.deepnorm:
                self.encoder = HubertEncoder_extend(config)
            else:
                self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # return super().forward(input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        # add additional cqt features for transformer input
        if self.config.feature_extractor_cqt:
            features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
            features_cqt = features_cqt[:, :extract_features.shape[1], :]  # align shape
            # # v2
            # features_cqt = self.post_cqt_feature_proj(features_cqt)
            # extract_features = self.feature_projection.layer_norm(extract_features) + self.feature_projection.layer_norm(features_cqt)  # v2
            # v3
            extract_features = torch.cat([extract_features, features_cqt], 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]  # take last_hidden_state from the encoder output

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

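# DeepNorm note (added comment): when `config.deepnorm` is enabled, the encoder
# below down-scales the FFN and value/output projection weights by
# init_scale = (8 * N) ** 0.25 for an N-layer encoder, i.e. it applies the
# DeepNet encoder gain beta = (8 * N) ** -0.25 on top of the standard init
# (Wang et al., 2022, "DeepNet: Scaling Transformers to 1,000 Layers").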
class HubertEncoder_extend(HubertEncoder):
    def __init__(self, config):
        # call nn.Module initialization directly instead of HubertEncoder.__init__()
        nn.Module.__init__(self)

        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

        if config.deepnorm:
            import math
            init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
            for name, p in self.named_parameters():
                if (
                    "feed_forward.intermediate_dense" in name
                    or "feed_forward.output_dense" in name
                    or "out_proj" in name
                    or "v_proj" in name
                ):
                    p.data.div_(init_scale)

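# The matching residual scaling lives in HubertEncoderLayerExtend below:
# residual_alpha = (2 * N) ** 0.25 is applied as `residual * alpha + x` in
# residual_connection(), the post-LN residual scaling used by DeepNet for
# encoder-only stacks; with deepnorm disabled it reduces to a plain residual.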
class HubertEncoderLayerExtend(HubertEncoderLayer):
    def __init__(self, config):
        # call nn.Module initialization directly instead of HubertEncoderLayer.__init__()
        nn.Module.__init__(self)
        if config.attention_relax > 0:
            self.attention = HubertAttention_extend(
                embed_dim=config.hidden_size,
                num_heads=config.num_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=False,
                attention_relax=config.attention_relax,
            )
        else:
            self.attention = HubertAttention(
                embed_dim=config.hidden_size,
                num_heads=config.num_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=False,
            )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if config.deepnorm:
            import math
            self.residual_alpha = math.pow(2.0 * config.num_hidden_layers, 0.25)
        else:
            self.residual_alpha = 1.0

    def residual_connection(self, x, residual):
        """
        residual: input before f()
        x: output of f(residual)
        """
        return residual * self.residual_alpha + x

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)

        # hidden_states = attn_residual + hidden_states
        hidden_states = self.residual_connection(hidden_states, attn_residual)

        hidden_states = self.layer_norm(hidden_states)

        # hidden_states = hidden_states + self.feed_forward(hidden_states)
        ffn_residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.residual_connection(hidden_states, ffn_residual)

        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

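# attention_relax note (added comment): HubertAttention_extend rescales the
# attention logits by 1 / attention_relax, subtracts the row-wise max, and
# multiplies back by attention_relax before the softmax. Because softmax is
# shift-invariant, the probabilities are unchanged in exact arithmetic; the
# rewrite only keeps the logits in a numerically safer range (e.g. for fp16).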
class HubertAttention_extend(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        attention_relax: float = -1.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # always store the value so forward() can test it; values <= 0 disable the relaxation
        self.attention_relax = attention_relax

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.attention_relax > 0:
            # => (bsz * self.num_heads, tgt_len, src_len)
            attn_weights_relax = attn_weights / self.attention_relax

            # => (bsz * self.num_heads, tgt_len, 1); torch.max(..., dim=...) returns a
            # (values, indices) tuple, so take .values before unsqueezing
            attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).values.unsqueeze(2)
            attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
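For reference, a minimal usage sketch for the classes added above (not part of the committed file). It assumes the repository also ships configuration_MERT.py and a Wav2Vec2FeatureExtractor-style preprocessor, and it uses the public m-a-p/MERT-v1-95M checkpoint id only as a placeholder; adjust the names to the actual repository.

# Hedged usage sketch: load the custom MERT classes via trust_remote_code and
# stack the per-layer hidden states for a few seconds of audio. The checkpoint
# id below is an assumption, not something defined in this commit.
import torch
from transformers import AutoModel, Wav2Vec2FeatureExtractor

model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

waveform = torch.randn(processor.sampling_rate * 5)  # 5 seconds of dummy mono audio
inputs = processor(waveform.numpy(), sampling_rate=processor.sampling_rate, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# (num_layers + 1, batch, frames, hidden_size): one entry per transformer layer plus the input projection
all_layer_states = torch.stack(outputs.hidden_states)
print(all_layer_states.shape)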