akanatas commited on
Commit
77b7ebb
·
1 Parent(s): c744921

Add configuration_MERT.py

Browse files
Files changed (1) hide show
  1. configuration_MERT.py +134 -0
configuration_MERT.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MERT model configuration.
3
+
4
+ Adapted from: https://github.com/yizhilll/MERT/blob/main/scripts/mert_hf/configuration_MERT.py
5
+ """
6
+
7
+ import functools
8
+ import operator
9
+
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ class MERTConfig(PretrainedConfig):
16
+ r"""
17
+ """
18
+ model_type = "mert_model"
19
+
20
+ def __init__(
21
+ self,
22
+ vocab_size=32,
23
+ hidden_size=768,
24
+ num_hidden_layers=12,
25
+ num_attention_heads=12,
26
+ intermediate_size=3072,
27
+ hidden_act="gelu",
28
+ hidden_dropout=0.1,
29
+ activation_dropout=0.1,
30
+ attention_dropout=0.1,
31
+ feat_proj_layer_norm=True,
32
+ feat_proj_dropout=0.0,
33
+ final_dropout=0.1,
34
+ layerdrop=0.1,
35
+ initializer_range=0.02,
36
+ layer_norm_eps=1e-5,
37
+ feat_extract_norm="group",
38
+ feat_extract_activation="gelu",
39
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
40
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
41
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
42
+ conv_bias=False,
43
+ num_conv_pos_embeddings=128,
44
+ num_conv_pos_embedding_groups=16,
45
+ do_stable_layer_norm=False,
46
+ apply_spec_augment=True,
47
+ mask_time_prob=0.05,
48
+ mask_time_length=10,
49
+ mask_time_min_masks=2,
50
+ mask_feature_prob=0.0,
51
+ mask_feature_length=10,
52
+ mask_feature_min_masks=0,
53
+ ctc_loss_reduction="sum",
54
+ ctc_zero_infinity=False,
55
+ use_weighted_layer_sum=False,
56
+ classifier_proj_size=256,
57
+ pad_token_id=0,
58
+ bos_token_id=1,
59
+ eos_token_id=2,
60
+ feature_extractor_cqt=False,
61
+ feature_extractor_cqt_bins=336,
62
+ deepnorm=False,
63
+ attention_relax=-1.0,
64
+ **kwargs
65
+ ):
66
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
67
+ self.hidden_size = hidden_size
68
+ self.feat_extract_norm = feat_extract_norm
69
+ self.feat_extract_activation = feat_extract_activation
70
+ self.conv_dim = list(conv_dim)
71
+ self.conv_stride = list(conv_stride)
72
+ self.conv_kernel = list(conv_kernel)
73
+ self.conv_bias = conv_bias
74
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
75
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
76
+ self.num_feat_extract_layers = len(self.conv_dim)
77
+ self.num_hidden_layers = num_hidden_layers
78
+ self.intermediate_size = intermediate_size
79
+ self.hidden_act = hidden_act
80
+ self.num_attention_heads = num_attention_heads
81
+ self.hidden_dropout = hidden_dropout
82
+ self.attention_dropout = attention_dropout
83
+ self.activation_dropout = activation_dropout
84
+ self.feat_proj_layer_norm = feat_proj_layer_norm
85
+ self.feat_proj_dropout = feat_proj_dropout
86
+ self.final_dropout = final_dropout
87
+ self.layerdrop = layerdrop
88
+ self.layer_norm_eps = layer_norm_eps
89
+ self.initializer_range = initializer_range
90
+ self.vocab_size = vocab_size
91
+ self.do_stable_layer_norm = do_stable_layer_norm
92
+ self.use_weighted_layer_sum = use_weighted_layer_sum
93
+ self.classifier_proj_size = classifier_proj_size
94
+
95
+ if (
96
+ (len(self.conv_stride) != self.num_feat_extract_layers)
97
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
98
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
99
+ ):
100
+ raise ValueError(
101
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
102
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
103
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
104
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
105
+ )
106
+
107
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
108
+ self.apply_spec_augment = apply_spec_augment
109
+ self.mask_time_prob = mask_time_prob
110
+ self.mask_time_length = mask_time_length
111
+ self.mask_time_min_masks = mask_time_min_masks
112
+ self.mask_feature_prob = mask_feature_prob
113
+ self.mask_feature_length = mask_feature_length
114
+ self.mask_feature_min_masks = mask_feature_min_masks
115
+
116
+ # ctc loss
117
+ self.ctc_loss_reduction = ctc_loss_reduction
118
+ self.ctc_zero_infinity = ctc_zero_infinity
119
+
120
+ # cqt feature extractor
121
+ self.feature_extractor_cqt = feature_extractor_cqt
122
+ self.feature_extractor_cqt_bins = feature_extractor_cqt_bins
123
+
124
+ # deepnorm: up-scale weighted residual conection + down-scale initial value transformer encoder
125
+ self.deepnorm = deepnorm
126
+
127
+ self.attention_relax = attention_relax
128
+
129
+ # fix bug with hf > 4.42
130
+ self.conv_pos_batch_norm = False
131
+
132
+ @property
133
+ def inputs_to_logits_ratio(self):
134
+ return functools.reduce(operator.mul, self.conv_stride, 1)