saurabhati committed · verified
Commit 3be7cc4 · 1 Parent(s): 4e65175

Upload feature extractor

feature_extraction_dass.py ADDED
@@ -0,0 +1,242 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Feature extractor class for DASS.
+ """
+ # based on https://github.com/huggingface/transformers/blob/v4.49.0/src/
+ # transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
+ # added htk_compat=True to the ta_kaldi.fbank call
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.utils import TensorType, is_speech_available, is_torch_available, logging
+
+
+ if is_speech_available():
+     import torchaudio.compliance.kaldi as ta_kaldi
+
+ if is_torch_available():
+     import torch
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class DASSFeatureExtractor(SequenceFeatureExtractor):
+     r"""
+     Constructs a Distilled Audio State-Space (DASS) feature extractor.
+
+     This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+     most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+     This class extracts mel-filter bank features from raw speech using TorchAudio if installed, or numpy otherwise,
+     pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
+
+     Args:
+         feature_size (`int`, *optional*, defaults to 1):
+             The feature dimension of the extracted features.
+         sampling_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+         num_mel_bins (`int`, *optional*, defaults to 128):
+             Number of Mel-frequency bins.
+         max_length (`int`, *optional*, defaults to 1024):
+             Maximum length to which to pad/truncate the extracted features.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether or not to normalize the log-Mel features using `mean` and `std`.
+         mean (`float`, *optional*, defaults to -4.2677393):
+             The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
+         std (`float`, *optional*, defaults to 4.5689974):
+             The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation
+             by default.
+         return_attention_mask (`bool`, *optional*, defaults to `False`):
+             Whether or not [`~DASSFeatureExtractor.__call__`] should return `attention_mask`.
+     """
+
+     model_input_names = ["input_values", "attention_mask"]
+
+     def __init__(
+         self,
+         feature_size=1,
+         sampling_rate=16000,
+         num_mel_bins=128,
+         max_length=1024,
+         padding_value=0.0,
+         do_normalize=True,
+         mean=-4.2677393,
+         std=4.5689974,
+         return_attention_mask=False,
+         **kwargs,
+     ):
+         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+         self.num_mel_bins = num_mel_bins
+         self.max_length = max_length
+         self.do_normalize = do_normalize
+         self.mean = mean
+         self.std = std
+         self.return_attention_mask = return_attention_mask
+
+         if not is_speech_available():
+             mel_filters = mel_filter_bank(
+                 num_frequency_bins=256,
+                 num_mel_filters=self.num_mel_bins,
+                 min_frequency=20,
+                 max_frequency=sampling_rate // 2,
+                 sampling_rate=sampling_rate,
+                 norm=None,
+                 mel_scale="kaldi",
+                 triangularize_in_mel_space=True,
+             )
+
+             self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
+             self.window = window_function(400, "hann", periodic=False)
+
+     def _extract_fbank_features(
+         self,
+         waveform: np.ndarray,
+         max_length: int,
+     ) -> np.ndarray:
+         """
+         Get mel-filter bank features using TorchAudio if available, otherwise numpy.
+         """
+         if is_speech_available():
+             waveform = torch.from_numpy(waveform).unsqueeze(0)
+             waveform = waveform - waveform.mean()
+             fbank = ta_kaldi.fbank(
+                 waveform,
+                 sample_frequency=self.sampling_rate,
+                 window_type="hanning",
+                 num_mel_bins=self.num_mel_bins,
+                 htk_compat=True,
+             )
+         else:
+             waveform = np.squeeze(waveform)
+             fbank = spectrogram(
+                 waveform,
+                 self.window,
+                 frame_length=400,
+                 hop_length=160,
+                 fft_length=512,
+                 power=2.0,
+                 center=False,
+                 preemphasis=0.97,
+                 mel_filters=self.mel_filters,
+                 log_mel="log",
+                 mel_floor=1.192092955078125e-07,
+                 remove_dc_offset=True,
+             ).T
+
+             fbank = torch.from_numpy(fbank)
+
+         n_frames = fbank.shape[0]
+         difference = max_length - n_frames
+
+         # pad or truncate, depending on difference
+         if difference > 0:
+             pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
+             fbank = pad_module(fbank)
+         elif difference < 0:
+             fbank = fbank[0:max_length, :]
+
+         fbank = fbank.numpy()
+
+         return fbank
+
+     def normalize(self, input_values: np.ndarray) -> np.ndarray:
+         return (input_values - (self.mean)) / (self.std * 2)
+
+     def __call__(
+         self,
+         raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+         sampling_rate: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Main method to featurize and prepare for the model one or several sequence(s).
+
+         Args:
+             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+                 values, a list of numpy arrays or a list of lists of float values. Must be mono channel audio, not
+                 stereo, i.e. a single float per timestep.
+             sampling_rate (`int`, *optional*):
+                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                 `sampling_rate` at the forward call to prevent silent errors.
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors instead of lists of python floats. Acceptable values are:
+
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return Numpy `np.ndarray` objects.
+         """
+
+         if sampling_rate is not None:
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError(
+                     f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                     f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                     f" {self.sampling_rate} and not {sampling_rate}."
+                 )
+         else:
+             logger.warning(
+                 "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+                 "Failing to do so can result in silent errors that might be hard to debug."
+             )
+
+         is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+         if is_batched_numpy and len(raw_speech.shape) > 2:
+             raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+         is_batched = is_batched_numpy or (
+             isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+         )
+
+         if is_batched:
+             raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
+         elif not is_batched and not isinstance(raw_speech, np.ndarray):
+             raw_speech = np.asarray(raw_speech, dtype=np.float32)
+         elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+             raw_speech = raw_speech.astype(np.float32)
+
+         # always return batch
+         if not is_batched:
+             raw_speech = [raw_speech]
+
+         # extract fbank features and pad/truncate to max_length
+         features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]
+
+         # convert into BatchFeature
+         padded_inputs = BatchFeature({"input_values": features})
+
+         # make sure list is in array format
+         input_values = padded_inputs.get("input_values")
+         if isinstance(input_values[0], list):
+             padded_inputs["input_values"] = [np.asarray(feature, dtype=np.float32) for feature in input_values]
+
+         # normalization
+         if self.do_normalize:
+             padded_inputs["input_values"] = [self.normalize(feature) for feature in input_values]
+
+         if return_tensors is not None:
+             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+         return padded_inputs
+
+
+ __all__ = ["DASSFeatureExtractor"]
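
For reference, a minimal sketch of exercising the class above directly, assuming feature_extraction_dass.py is importable and torch/torchaudio are installed; the waveform is illustrative, not from the commit. Note that `normalize` scales by two standard deviations, following the AST feature extractor this file is based on.

import numpy as np
from feature_extraction_dass import DASSFeatureExtractor

extractor = DASSFeatureExtractor()  # defaults: 16 kHz, 128 mel bins, max_length=1024

# ten seconds of random mono audio at 16 kHz (illustrative input only)
waveform = np.random.randn(160000).astype(np.float32)
inputs = extractor(waveform, sampling_rate=16000, return_tensors="np")

# features are padded/truncated to max_length frames of num_mel_bins coefficients
print(inputs["input_values"].shape)  # (1, 1024, 128)
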
preprocessor_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "auto_map": {
+     "AutoFeatureExtractor": "feature_extraction_dass.DASSFeatureExtractor"
+   },
+   "do_normalize": true,
+   "feature_extractor_type": "DASSFeatureExtractor",
+   "feature_size": 1,
+   "max_length": 1024,
+   "mean": -4.2677393,
+   "num_mel_bins": 128,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": false,
+   "sampling_rate": 16000,
+   "std": 4.5689974
+ }
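
Since preprocessor_config.json maps AutoFeatureExtractor to feature_extraction_dass.DASSFeatureExtractor via auto_map, the extractor can also be loaded from the Hub with remote code enabled. A minimal sketch; the repo id "saurabhati/DASS" is a placeholder assumption, not confirmed by this commit:

import numpy as np
from transformers import AutoFeatureExtractor

# auto_map points at custom code in the repo, so trust_remote_code is required
extractor = AutoFeatureExtractor.from_pretrained("saurabhati/DASS", trust_remote_code=True)

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)  # torch.Size([1, 1024, 128])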