# preprocess.py — audio preprocessing utilities for the audio-classification task.
# Provenance: uploaded as submission-audio-task/tasks/preprocess.py (commit 02ed262).
import numpy as np
import librosa
import torch
import torchaudio
from tqdm import tqdm
import warnings
# Target sampling rate in Hz.
# NOTE(review): the functions below hard-code 12000 instead of referencing
# this constant — keep the two in sync if either ever changes.
SR=12000
def basic_stats_dataset(dataset):
    """Collect per-example statistics from a HuggingFace-style audio dataset.

    Each row is expected to look like
    ``{"audio": {"array": np.ndarray, "sampling_rate": int}, "label": int}``.

    Returns three parallel numpy arrays:
    (signal sizes in samples, sampling rates, labels).
    """
    sizes, srs, labels = [], [], []
    for row in dataset:
        audio = row["audio"]
        sizes.append(audio["array"].size)
        srs.append(audio["sampling_rate"])
        labels.append(row["label"])
    return np.array(sizes), np.array(srs), np.array(labels)
# This function loads all data in a huggingface dataset, into a numpy array
def get_raw_data(dataset, pad="constant", dtype=np.float32, target_sr=12000, max_len=36000):
    """Load an entire HuggingFace audio dataset into fixed-size numpy arrays.

    Every signal is resampled to ``target_sr`` Hz, then truncated or padded
    to exactly ``max_len`` samples (3 s at the default 12 kHz).

    Parameters
    ----------
    dataset : sized iterable of rows shaped
        ``{"audio": {"array", "sampling_rate"}, "label"}``.
    pad : numpy pad mode for short signals ("constant" pads with zeros).
    dtype : dtype of the returned signal matrix.
    target_sr : common sampling rate every signal is resampled to.
    max_len : fixed number of samples per signal.

    Returns
    -------
    (signals, labels, sizes) :
        arrays of shape ``(n, max_len)``, ``(n,)``, ``(n,)``; ``sizes``
        holds each signal's original (pre-resampling) sample count.
    """
    signals = np.zeros((len(dataset), max_len), dtype=dtype)
    labels = np.zeros(len(dataset), dtype=np.uint8)
    sizes = np.zeros(len(dataset), dtype=int)
    for i, row in enumerate(dataset):
        signal = row["audio"]["array"]
        sr = row["audio"]["sampling_rate"]
        sizes[i] = signal.size
        labels[i] = row["label"]
        # RESAMPLING to the common rate, before any length handling
        if sr != target_sr:
            signal = librosa.resample(signal, orig_sr=sr, target_sr=target_sr)
        # Truncate signals longer than max_len
        if signal.size > max_len:
            warnings.warn(f"Signal > {max_len}. Truncate the signal")
            signal = signal[:max_len]
        # PADDING short signals
        elif signal.size < max_len:
            if signal.size == 0:
                # np.pad with non-constant modes fails on empty input,
                # so an empty signal becomes pure silence.
                signal = np.zeros(max_len, dtype=dtype)
            elif pad == "constant":
                signal = np.pad(signal, (0, max_len - signal.size), mode="constant", constant_values=0)
            else:
                signal = np.pad(signal, (0, max_len - signal.size), mode=pad)
        signals[i, :] = signal
    return signals, labels, sizes
# This is a generator, doing the same as the function above but load data by batch
# (lower memory usage for inference)
def get_batch_generator(dataset, bs, pad="constant", target_sr=12000, max_len=36000):
    """Yield ``(signals, labels, sizes)`` batches of at most ``bs`` examples.

    Same per-example processing as ``get_raw_data`` (resample to
    ``target_sr``, truncate/pad to ``max_len`` samples) but streamed batch
    by batch to keep memory usage low during inference. The final batch may
    hold fewer than ``bs`` examples.

    Parameters
    ----------
    dataset : iterable of rows shaped
        ``{"audio": {"array", "sampling_rate"}, "label"}``.
    bs : batch size.
    pad : numpy pad mode for short signals ("constant" pads with zeros).
    target_sr : common sampling rate every signal is resampled to.
    max_len : fixed number of samples per signal.
    """
    def process_signal(row):
        # Resample/truncate/pad one example; returns (signal, label, orig_size).
        signal = row["audio"]["array"]
        sr = row["audio"]["sampling_rate"]
        label = row["label"]
        size = signal.size
        # RESAMPLING to the common rate
        if sr != target_sr:
            signal = librosa.resample(signal, orig_sr=sr, target_sr=target_sr)
        # Truncate signals longer than max_len
        if signal.size > max_len:
            warnings.warn(f"Signal > {max_len}. Truncate the signal")
            signal = signal[:max_len]
        # PADDING short signals
        elif signal.size < max_len:
            if signal.size == 0:
                # np.pad with non-constant modes fails on empty input.
                signal = np.zeros(max_len, dtype=np.float32)
            elif pad == "constant":
                signal = np.pad(signal, (0, max_len - signal.size), mode="constant", constant_values=0)
            else:
                signal = np.pad(signal, (0, max_len - signal.size), mode=pad)
        return signal, label, size

    def new_buffers():
        # Fresh buffers for every yielded batch: the consumer may keep a
        # reference to the yielded arrays, so they must not be overwritten.
        return (np.zeros((bs, max_len), dtype=np.float32),
                np.zeros(bs, dtype=np.uint8),
                np.zeros(bs, dtype=int))

    batch_signals, batch_labels, batch_sizes = new_buffers()
    batch_index = 0
    for row in dataset:
        signal, label, size = process_signal(row)
        batch_signals[batch_index] = signal
        batch_labels[batch_index] = label
        batch_sizes[batch_index] = size
        batch_index += 1
        if batch_index == bs:  # If the batch is full, yield it
            yield batch_signals, batch_labels, batch_sizes
            batch_signals, batch_labels, batch_sizes = new_buffers()
            batch_index = 0
    # Handle the last batch if it is not full
    if batch_index > 0:
        yield batch_signals[:batch_index], batch_labels[:batch_index], batch_sizes[:batch_index]
class FeatureExtractor():
    """GPU feature extractor producing both XGBoost and CNN inputs.

    Builds three CUDA-resident torchaudio transforms at 12 kHz:
    a mel spectrogram + MFCC pair for tabular (XGBoost) features, and a
    second mel spectrogram for the CNN branch.

    NOTE(review): requires a CUDA device — every transform and batch is
    moved to GPU unconditionally.
    """
    def __init__(self, xgboost_kwargs_mel_spectrogram, xgboost_kwargs_MFCC, cnn_kwargs_spectrogram, mean_spec = 0.17555018, std_spec = 0.19079028):
        """Store normalization stats and build the CUDA transforms.

        Parameters
        ----------
        xgboost_kwargs_mel_spectrogram : kwargs for the XGBoost-branch MelSpectrogram.
        xgboost_kwargs_MFCC : kwargs for the MFCC transform; must contain "n_mfcc".
        cnn_kwargs_spectrogram : kwargs for the CNN-branch MelSpectrogram.
        mean_spec, std_spec : mean/std used to normalize the CNN log-spectrogram
            (defaults presumably precomputed on the training set — TODO confirm).
        """
        self.mel_transform_xgboost = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **xgboost_kwargs_mel_spectrogram
        ).cuda()
        self.mel_transform_cnn = torchaudio.transforms.MelSpectrogram(
            sample_rate=12000,
            **cnn_kwargs_spectrogram
        ).cuda()
        self.MFCC = torchaudio.transforms.MFCC(
            sample_rate=12000,
            **xgboost_kwargs_MFCC
        ).cuda()
        # Number of MFCC coefficients; fixes the tabular feature layout below.
        self.n_mfcc = xgboost_kwargs_MFCC["n_mfcc"]
        self.mean = mean_spec
        self.std = std_spec

    def transform(self, batch):
        """Compute features for a batch of raw signals.

        Parameters
        ----------
        batch : array-like of shape (batch, n_samples), raw audio at 12 kHz.

        Returns
        -------
        dict with:
        - "xgboost": numpy array (batch, n_mfcc*2 + 3*n_mels) of
          [MFCC mean | MFCC std | mel mean | mel std | mel-delta std];
        - "CNN": CUDA tensor (batch, 1, n_mels, time) of normalized
          log10(1 + mel) spectrograms.
        """
        batch = torch.as_tensor(batch).cuda()
        # XGBOOST features: per-coefficient temporal mean and std of the MFCCs
        mfcc_features = np.zeros((batch.size(0), self.n_mfcc*2), dtype=np.float32)
        mfcc_batch = self.MFCC(batch)
        mfcc_features[:,:self.n_mfcc] = mfcc_batch.mean(-1).cpu().numpy()
        mfcc_features[:,self.n_mfcc:] = mfcc_batch.std(-1).cpu().numpy()
        # Mel statistics: temporal mean/std plus std of the delta (first derivative)
        mel_spectrograms = self.mel_transform_xgboost(batch)
        mel_spectrograms_delta = torchaudio.functional.compute_deltas(mel_spectrograms)
        mel_features = np.hstack((
            mel_spectrograms.mean(-1).cpu(),
            mel_spectrograms.std(-1).cpu(),
            mel_spectrograms_delta.std(-1).cpu(),
        ))
        xgboost_features = np.hstack((mfcc_features, mel_features))
        # CNN spectrogram: log-compress, normalize, add a channel dimension
        spectrograms = self.mel_transform_cnn(batch)
        spectrograms = torch.log10(1+spectrograms)
        spectrograms = (spectrograms-self.mean)/self.std
        spectrograms = spectrograms.unsqueeze(1)
        return {"xgboost" : xgboost_features, "CNN": spectrograms}