kangourous committed on
Commit
02ed262
·
verified ·
1 Parent(s): c3b44f3

Upload preprocess.py

Browse files
Files changed (1) hide show
  1. tasks/preprocess.py +177 -0
tasks/preprocess.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import librosa
3
+ import torch
4
+ import torchaudio
5
+ from tqdm import tqdm
6
+ import warnings
7
+
8
# Audio sampling rate in Hz.
# NOTE(review): nothing visible in this file reads SR; the code spells out
# 12000 literally instead. Presumably kept for importers -- confirm.
SR=12000
9
+
10
def basic_stats_dataset(dataset):
    """Gather per-example sample counts, sampling rates and labels.

    Every row of ``dataset`` is read through ``row["audio"]["array"]``,
    ``row["audio"]["sampling_rate"]`` and ``row["label"]``.

    Returns:
        A tuple ``(sizes, srs, labels)`` of NumPy arrays holding one entry
        per row, in iteration order.
    """
    # One pass over the dataset; each tuple is (sample count, rate, label).
    stats = [
        (
            row["audio"]["array"].size,
            row["audio"]["sampling_rate"],
            row["label"],
        )
        for row in dataset
    ]

    sizes = np.array([n_samples for n_samples, _, _ in stats])
    srs = np.array([rate for _, rate, _ in stats])
    labels = np.array([target for _, _, target in stats])
    return sizes, srs, labels
28
+
29
+
30
def _prepare_signal(row, pad="constant", target_sr=12000, n_samples=36000):
    """Resample one dataset row and force it to exactly ``n_samples`` samples.

    Shared by ``get_raw_data`` and ``get_batch_generator`` so both code paths
    apply the same preprocessing.

    Args:
        row: Mapping read through ``row["audio"]["array"]``,
            ``row["audio"]["sampling_rate"]`` and ``row["label"]``.
        pad: ``numpy.pad`` mode used to extend short signals on the right.
        target_sr: Sampling rate the signal is resampled to when it differs.
        n_samples: Output length; longer signals are truncated (with a
            warning), shorter ones are padded.

    Returns:
        ``(signal, label, size)`` where ``size`` is the sample count of the
        signal as stored in the row, i.e. before resampling.

    Raises:
        ValueError: If the processed signal is not a 1-D array of length
            ``n_samples`` (e.g. the input was not one-dimensional).
    """
    signal = row["audio"]["array"]
    sr = row["audio"]["sampling_rate"]
    label = row["label"]
    size = signal.size  # recorded before any resampling/truncation/padding

    # An empty clip is replaced by zeros below, so there is nothing to resample.
    if sr != target_sr and signal.size > 0:
        signal = librosa.resample(signal, orig_sr=sr, target_sr=target_sr)

    if signal.size > n_samples:
        warnings.warn(f"Signal > {n_samples}. Truncate the signal")
        signal = signal[:n_samples]
    elif signal.size < n_samples:
        pad_width = (0, n_samples - signal.size)
        if signal.size == 0:
            # Data-dependent pad modes cannot extend an empty axis.
            signal = np.zeros(n_samples)
        elif pad == "constant":
            signal = np.pad(signal, pad_width, mode="constant", constant_values=0)
        else:
            signal = np.pad(signal, pad_width, mode=pad)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if signal.shape != (n_samples,):
        raise ValueError(
            f"Expected a 1-D signal of {n_samples} samples, got shape {signal.shape}"
        )
    return signal, label, size


def get_raw_data(dataset, pad="constant", dtype=np.float32, *, target_sr=12000, n_samples=36000):
    """Load a whole Hugging Face audio dataset into fixed-size NumPy arrays.

    Args:
        dataset: Sized iterable of rows (needs ``len()``); see
            ``_prepare_signal`` for the keys that are read.
        pad: ``numpy.pad`` mode used for signals shorter than ``n_samples``.
        dtype: dtype of the returned signal matrix.
        target_sr: Sampling rate every signal is resampled to.
        n_samples: Number of samples kept per signal (36000 is 3 s at 12 kHz).

    Returns:
        ``(signals, labels, sizes)``: ``signals`` has shape
        ``(len(dataset), n_samples)``; ``labels`` is ``uint8``; ``sizes``
        holds each signal's original sample count.
    """
    n_rows = len(dataset)
    signals = np.zeros((n_rows, n_samples), dtype=dtype)
    labels = np.zeros(n_rows, dtype=np.uint8)
    sizes = np.zeros(n_rows, dtype=int)

    for i, row in enumerate(dataset):
        signals[i], labels[i], sizes[i] = _prepare_signal(
            row, pad, target_sr, n_samples
        )

    return signals, labels, sizes


def get_batch_generator(dataset, bs, pad="constant", *, target_sr=12000, n_samples=36000):
    """Yield the same data as ``get_raw_data`` batch by batch.

    Only one batch is held in memory at a time, which keeps memory usage low
    for inference.

    Args:
        dataset: Iterable of rows; see ``_prepare_signal`` for the keys read.
        bs: Batch size, must be at least 1.
        pad: ``numpy.pad`` mode used for signals shorter than ``n_samples``.
        target_sr: Sampling rate every signal is resampled to.
        n_samples: Number of samples kept per signal.

    Yields:
        ``(signals, labels, sizes)`` with ``signals`` of shape
        ``(bs, n_samples)`` and dtype ``float32``; the final batch is shorter
        when the dataset length is not a multiple of ``bs``.

    Raises:
        ValueError: On first iteration, if ``bs`` is smaller than 1.
    """
    if bs < 1:
        raise ValueError(f"bs must be >= 1, got {bs}")

    def new_buffers():
        return (
            np.zeros((bs, n_samples), dtype=np.float32),
            np.zeros(bs, dtype=np.uint8),
            np.zeros(bs, dtype=int),
        )

    batch_signals, batch_labels, batch_sizes = new_buffers()
    filled = 0

    for row in dataset:
        (
            batch_signals[filled],
            batch_labels[filled],
            batch_sizes[filled],
        ) = _prepare_signal(row, pad, target_sr, n_samples)
        filled += 1

        if filled == bs:
            yield batch_signals, batch_labels, batch_sizes
            # Allocate fresh buffers so batches already handed to the caller
            # are not overwritten by the next one.
            batch_signals, batch_labels, batch_sizes = new_buffers()
            filled = 0

    # Emit the trailing, partially filled batch.
    if filled > 0:
        yield batch_signals[:filled], batch_labels[:filled], batch_sizes[:filled]
126
+
127
class FeatureExtractor():
    """Compute XGBoost summary features and CNN spectrograms from raw audio.

    Three torchaudio transforms are built once (all at a 12000 Hz sample
    rate) and moved to ``device``: a mel spectrogram and an MFCC for the
    XGBoost features, and a second mel spectrogram for the CNN input.
    """

    def __init__(self, xgboost_kwargs_mel_spectrogram, xgboost_kwargs_MFCC, cnn_kwargs_spectrogram, mean_spec = 0.17555018, std_spec = 0.19079028, device=None):
        """Build the transforms.

        Args:
            xgboost_kwargs_mel_spectrogram: Keyword arguments forwarded to
                ``torchaudio.transforms.MelSpectrogram`` (XGBoost branch).
            xgboost_kwargs_MFCC: Keyword arguments forwarded to
                ``torchaudio.transforms.MFCC``; must contain ``"n_mfcc"``.
            cnn_kwargs_spectrogram: Keyword arguments forwarded to
                ``torchaudio.transforms.MelSpectrogram`` (CNN branch).
            mean_spec: Mean subtracted from the log spectrogram.
            std_spec: Standard deviation the log spectrogram is divided by.
            device: Device the transforms and inputs are placed on. ``None``
                selects CUDA when available and falls back to CPU otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.mel_transform_xgboost = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **xgboost_kwargs_mel_spectrogram
        ).to(self.device)

        self.mel_transform_cnn = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **cnn_kwargs_spectrogram
        ).to(self.device)

        self.MFCC = torchaudio.transforms.MFCC(
            sample_rate=12000,
            **xgboost_kwargs_MFCC
        ).to(self.device)

        self.n_mfcc = xgboost_kwargs_MFCC["n_mfcc"]
        self.mean = mean_spec
        self.std = std_spec

    def transform(self, batch):
        """Extract both feature sets for a batch of waveforms.

        Args:
            batch: Array-like accepted by ``torch.as_tensor``; assumed to be
                ``(batch, samples)`` float32 waveforms -- TODO confirm with
                callers.

        Returns:
            A dict with two entries. ``"xgboost"`` is a NumPy array
            stacking, column-wise, the MFCC mean and std, the mel spectrogram
            mean and std, and the std of the mel deltas, each reduced over
            the last axis. ``"CNN"`` is a tensor left on ``self.device``
            holding ``log10(1 + mel)`` normalised with ``mean``/``std``,
            with a singleton dimension inserted at position 1.
        """
        batch = torch.as_tensor(batch).to(self.device)

        # XGBOOST features
        mfcc_batch = self.MFCC(batch)
        # The MFCC statistics are stored as float32, whatever the transform emits.
        mfcc_features = np.hstack((
            mfcc_batch.mean(-1).cpu().numpy(),
            mfcc_batch.std(-1).cpu().numpy(),
        )).astype(np.float32, copy=False)

        mel_spectrograms = self.mel_transform_xgboost(batch)
        mel_spectrograms_delta = torchaudio.functional.compute_deltas(mel_spectrograms)
        mel_features = np.hstack((
            mel_spectrograms.mean(-1).cpu().numpy(),
            mel_spectrograms.std(-1).cpu().numpy(),
            mel_spectrograms_delta.std(-1).cpu().numpy(),
        ))
        xgboost_features = np.hstack((mfcc_features, mel_features))

        # CNN spectrogram
        spectrograms = self.mel_transform_cnn(batch)
        spectrograms = torch.log10(1+spectrograms)  # log compression
        spectrograms = (spectrograms-self.mean)/self.std
        spectrograms = spectrograms.unsqueeze(1)

        return {"xgboost" : xgboost_features, "CNN": spectrograms}
177
+