import numpy as np
import librosa
import torch
import torchaudio
from tqdm import tqdm
import warnings

# Target sampling rate (Hz); clips are resampled to 12 kHz and padded/truncated
# to 3 s, i.e. 36000 samples.
SR = 12000
# Collect the length (in samples), sampling rate, and label of every clip.
def basic_stats_dataset(dataset):
    sizes = []
    srs = []
    labels = []
    for row in dataset:
        signal = row["audio"]["array"]
        sr = row["audio"]["sampling_rate"]
        label = row["label"]
        sizes.append(signal.size)
        srs.append(sr)
        labels.append(label)
    sizes = np.array(sizes)
    srs = np.array(srs)
    labels = np.array(labels)
    return sizes, srs, labels
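
# Example usage (sketch): assuming `ds` is a Hugging Face `datasets` split with
# "audio" and "label" columns (the name `ds` is illustrative, not from this Space):
#
#   sizes, srs, labels = basic_stats_dataset(ds)
#   print("samples per clip:", sizes.min(), "-", sizes.max())
#   print("sampling rates found:", np.unique(srs))
#   print("label counts:", np.bincount(labels))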
# This function loads every clip of a Hugging Face dataset into a single NumPy array
# (resampled to 12 kHz, truncated/padded to 36000 samples).
def get_raw_data(dataset, pad="constant", dtype=np.float32):
    signals = np.zeros((len(dataset), 36000), dtype=dtype)
    labels = np.zeros(len(dataset), dtype=np.uint8)
    sizes = np.zeros(len(dataset), dtype=int)
    for i, row in enumerate(dataset):
        signal = row["audio"]["array"]
        sr = row["audio"]["sampling_rate"]
        label = row["label"]
        size = signal.size
        # RESAMPLING to 12000 Hz
        if sr != 12000:
            signal = librosa.resample(signal, orig_sr=sr, target_sr=12000)
            sr = 12000
        assert sr == 12000
        # TRUNCATE signals longer than 3 s (36000 samples)
        if signal.size > 36000:
            warnings.warn("Signal longer than 36000 samples; truncating.")
            signal = signal[:36000]
        # PAD short signals up to 36000 samples
        elif signal.size < 36000:
            if signal.size == 0:
                signal = np.zeros(36000)
            elif pad == "constant":
                signal = np.pad(signal, (0, 36000 - signal.size), mode="constant", constant_values=0)
            else:
                signal = np.pad(signal, (0, 36000 - signal.size), mode=pad)
        assert signal.size == 36000
        labels[i] = label
        signals[i, :] = signal
        sizes[i] = size
    return signals, labels, sizes
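
# Example usage (sketch): loads the whole split into memory at once, so it suits
# smaller datasets; `ds` is assumed as above:
#
#   X, y, sizes = get_raw_data(ds, pad="constant")
#   # X.shape == (len(ds), 36000), y.shape == (len(ds),)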
# This generator does the same as the function above but yields the data in batches
# (lower memory usage for inference).
def get_batch_generator(dataset, bs, pad="constant"):
    def process_signal(row):
        signal = row["audio"]["array"]
        sr = row["audio"]["sampling_rate"]
        label = row["label"]
        size = signal.size
        # RESAMPLING to 12000 Hz
        if sr != 12000:
            signal = librosa.resample(signal, orig_sr=sr, target_sr=12000)
            sr = 12000
        assert sr == 12000
        # TRUNCATE signals longer than 3 s (36000 samples)
        if signal.size > 36000:
            warnings.warn("Signal longer than 36000 samples; truncating.")
            signal = signal[:36000]
        # PAD short signals up to 36000 samples
        elif signal.size < 36000:
            if signal.size == 0:
                signal = np.zeros(36000)
            elif pad == "constant":
                signal = np.pad(signal, (0, 36000 - signal.size), mode="constant", constant_values=0)
            else:
                signal = np.pad(signal, (0, 36000 - signal.size), mode=pad)
        assert signal.size == 36000
        return signal, label, size

    # Initialize batch buffers
    batch_signals = np.zeros((bs, 36000), dtype=np.float32)
    batch_labels = np.zeros(bs, dtype=np.uint8)
    batch_sizes = np.zeros(bs, dtype=int)
    batch_index = 0
    for row in dataset:
        signal, label, size = process_signal(row)
        batch_signals[batch_index] = signal
        batch_labels[batch_index] = label
        batch_sizes[batch_index] = size
        batch_index += 1
        if batch_index == bs:  # If the batch is full, yield it
            yield batch_signals, batch_labels, batch_sizes
            # Reset batch buffers
            batch_signals = np.zeros((bs, 36000), dtype=np.float32)
            batch_labels = np.zeros(bs, dtype=np.uint8)
            batch_sizes = np.zeros(bs, dtype=int)
            batch_index = 0
    # Handle the last batch if it is not full
    if batch_index > 0:
        yield batch_signals[:batch_index], batch_labels[:batch_index], batch_sizes[:batch_index]
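
# Example usage (sketch): iterate in bounded-size batches during inference;
# `ds` and the batch size of 32 are illustrative:
#
#   for signals, labels, sizes in get_batch_generator(ds, bs=32):
#       ...  # run inference on `signals` (shape (<=32, 36000))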
class FeatureExtractor():
    def __init__(self, xgboost_kwargs_mel_spectrogram, xgboost_kwargs_MFCC, cnn_kwargs_spectrogram,
                 mean_spec=0.17555018, std_spec=0.19079028):
        self.mel_transform_xgboost = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **xgboost_kwargs_mel_spectrogram
        ).cuda()
        self.mel_transform_cnn = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **cnn_kwargs_spectrogram
        ).cuda()
        self.MFCC = torchaudio.transforms.MFCC(
            sample_rate=12000,
            **xgboost_kwargs_MFCC
        ).cuda()
        self.n_mfcc = xgboost_kwargs_MFCC["n_mfcc"]
        self.mean = mean_spec
        self.std = std_spec

    def transform(self, batch):
        batch = torch.as_tensor(batch).cuda()
        # XGBOOST features: mean and std of each MFCC coefficient over time
        mfcc_features = np.zeros((batch.size(0), self.n_mfcc * 2), dtype=np.float32)
        mfcc_batch = self.MFCC(batch)
        mfcc_features[:, :self.n_mfcc] = mfcc_batch.mean(-1).cpu().numpy()
        mfcc_features[:, self.n_mfcc:] = mfcc_batch.std(-1).cpu().numpy()
        # Mel-spectrogram statistics: per-band mean, std, and delta std over time
        mel_spectrograms = self.mel_transform_xgboost(batch)
        mel_spectrograms_delta = torchaudio.functional.compute_deltas(mel_spectrograms)
        mel_features = np.hstack((
            mel_spectrograms.mean(-1).cpu(),
            mel_spectrograms.std(-1).cpu(),
            mel_spectrograms_delta.std(-1).cpu(),
        ))
        xgboost_features = np.hstack((mfcc_features, mel_features))
        # CNN spectrogram: log-compress, normalize with the precomputed dataset
        # statistics (mean_spec / std_spec), and add a channel dimension
        spectrograms = self.mel_transform_cnn(batch)
        spectrograms = torch.log10(1 + spectrograms)
        spectrograms = (spectrograms - self.mean) / self.std
        spectrograms = spectrograms.unsqueeze(1)
        return {"xgboost": xgboost_features, "CNN": spectrograms}