Add configuration_MERT.py
Browse files- configuration_MERT.py +134 -0
configuration_MERT.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
MERT model configuration.
|
3 |
+
|
4 |
+
Adapted from: https://github.com/yizhilll/MERT/blob/main/scripts/mert_hf/configuration_MERT.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import functools
|
8 |
+
import operator
|
9 |
+
|
10 |
+
from transformers.configuration_utils import PretrainedConfig
|
11 |
+
from transformers.utils import logging
|
12 |
+
|
13 |
+
# Module-level logger created through the `transformers` logging utilities imported above.
logger = logging.get_logger(__name__)
|
14 |
+
|
15 |
+
class MERTConfig(PretrainedConfig):
    r"""
    Configuration object holding the hyper-parameters of a MERT model.

    Every constructor argument is stored on the instance under the same name, except
    for the special-token ids and `**kwargs`, which are forwarded to
    `PretrainedConfig.__init__`. The argument names and defaults appear to mirror the
    HuBERT / wav2vec 2.0 style configurations shipped with `transformers` (verify
    against upstream before relying on their exact semantics), extended with a few
    MERT-specific switches.

    Args:
        vocab_size, hidden_size, num_hidden_layers, num_attention_heads,
        intermediate_size, hidden_act, layer_norm_eps, initializer_range,
        do_stable_layer_norm, use_weighted_layer_sum, classifier_proj_size:
            Stored unchanged as attributes of the same name.
        hidden_dropout, activation_dropout, attention_dropout, feat_proj_dropout,
        final_dropout, layerdrop:
            Dropout-related values, stored unchanged.
        feat_proj_layer_norm:
            Stored unchanged.
        feat_extract_norm, feat_extract_activation, conv_bias:
            Settings of the convolutional feature extractor, stored unchanged.
        conv_dim, conv_stride, conv_kernel:
            Per-layer settings of the convolutional layers. Each sequence is copied
            into a new `list`; `conv_stride` and `conv_kernel` must have the same
            length as `conv_dim`.
        num_conv_pos_embeddings, num_conv_pos_embedding_groups:
            Stored unchanged.
        apply_spec_augment, mask_time_prob, mask_time_length, mask_time_min_masks,
        mask_feature_prob, mask_feature_length, mask_feature_min_masks:
            SpecAugment (https://arxiv.org/abs/1904.08779) fine-tuning parameters.
        ctc_loss_reduction, ctc_zero_infinity:
            CTC loss settings.
        pad_token_id, bos_token_id, eos_token_id:
            Special-token ids, forwarded to `PretrainedConfig.__init__`.
        feature_extractor_cqt, feature_extractor_cqt_bins:
            Settings of the CQT feature extractor.
        deepnorm:
            DeepNorm switch: up-scale the weighted residual connection and
            down-scale the initial values of the transformer encoder.
        attention_relax:
            Stored unchanged; its interpretation is left to the model code.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    Derived attributes:
        num_feat_extract_layers: `len(conv_dim)`.
        conv_pos_batch_norm: always `False` (see the comment in `__init__`).
        inputs_to_logits_ratio: product of all entries of `conv_stride`.

    Raises:
        ValueError: if `conv_stride` or `conv_kernel` does not have the same number
            of entries as `conv_dim`.
    """

    model_type = "mert_model"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_layer_norm=True,
        feat_proj_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        feature_extractor_cqt=False,
        feature_extractor_cqt_bins=336,
        deepnorm=False,
        attention_relax=-1.0,
        **kwargs
    ):
        # The special-token ids are handled by the base class together with any
        # extra keyword arguments.
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        # Copy into lists so the tuple defaults (or caller-owned sequences) are
        # never shared with the instance.
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        # The number of convolutional layers is defined by `conv_dim`.
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_layer_norm = feat_proj_layer_norm
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.use_weighted_layer_sum = use_weighted_layer_sum
        self.classifier_proj_size = classifier_proj_size

        # `num_feat_extract_layers` is `len(conv_dim)` by construction, so only the
        # stride and kernel lists can disagree with it.
        if (
            (len(self.conv_stride) != self.num_feat_extract_layers)
            or (len(self.conv_kernel) != self.num_feat_extract_layers)
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # cqt feature extractor
        self.feature_extractor_cqt = feature_extractor_cqt
        self.feature_extractor_cqt_bins = feature_extractor_cqt_bins

        # deepnorm: up-scale weighted residual connection + down-scale initial value transformer encoder
        self.deepnorm = deepnorm

        self.attention_relax = attention_relax

        # fix bug with hf > 4.42
        # Always forced to False here, even if a value arrived through `kwargs`.
        self.conv_pos_batch_norm = False

    @property
    def inputs_to_logits_ratio(self):
        """Return the product of all convolutional strides (1 if there are none)."""
        return functools.reduce(operator.mul, self.conv_stride, 1)
|