Wav2Small2.0 - Arousal / Dominance / Valence
Please note that this model is for research purposes only. A commercial license can be obtained from audEERING. The model expects a raw audio signal sampled at 16 kHz as input and outputs arousal, dominance, and valence scores in the range [0, 1]. The model was created following the Wav2Small paper and has a total of 17K parameters.
How To
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from torch import nn
signal = torch.from_numpy(
librosa.load('test.wav', sr=16000)[0])[None, :]
device = 'cpu'
def _prenorm(x, attention_mask=None):
'''wav2vec2'''
if attention_mask is not None:
        N = attention_mask.sum(1, keepdim=True)  # attention_mask is at raw sample resolution, i.e. the same length as the input signal (1 = speech, 0 = padding)
x -= x.sum(1, keepdim=True) / N
var = (x * x).sum(1, keepdim=True) / N
else:
        x -= x.mean(1, keepdim=True)  # mean exports to the ONNX ReduceMean op; saves a few ops vs. casting the integer N to float and dividing
var = (x * x).mean(1, keepdim=True)
return x / torch.sqrt(var + 1e-7)
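# _prenorm standardises each utterance to zero mean and unit variance before the
# DFT; when attention_mask is given, the statistics are computed over the real
# (unpadded) samples only.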
class Spectrogram(nn.Module):
def __init__(self,
                 n_fft=64,  # number of DFT frequency rows (out_channels = n_fft // 2 + 1 bins are kept)
                 n_time=64,  # analysis window length in samples (columns of the DFT matrix)
hop_length=32,
freeze_parameters=True):
super().__init__()
fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
fft_window = librosa.util.pad_center(fft_window, size=n_time)
out_channels = n_fft // 2 + 1
(x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
omega = np.exp(-2 * np.pi * 1j / n_time)
dft_matrix = np.power(omega, x * y) # (n_fft, n_time)
dft_matrix = dft_matrix * fft_window[None, :]
dft_matrix = dft_matrix[0 : out_channels, :]
dft_matrix = dft_matrix[:, None, :]
        # ---- asymmetric (non-square) DFT
self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length,
padding=0, bias=False)
self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length,
padding=0, bias=False)
self.conv_real.weight.data = torch.tensor(np.real(dft_matrix),
dtype=self.conv_real.weight.dtype,
device=self.conv_real.weight.device)
self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix),
dtype=self.conv_imag.weight.dtype,
device=self.conv_imag.weight.device)
if freeze_parameters:
for param in self.parameters():
param.requires_grad = False
def forward(self, input):
x = input[:, None, :]
real = self.conv_real(x)
imag = self.conv_imag(x)
return real ** 2 + imag ** 2 # bs, freq, time-frames
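# Spectrogram implements the STFT as two frozen Conv1d layers that hold the real
# and imaginary parts of a Hann-windowed DFT; for a 1 s / 16 kHz input it returns
# a power spectrogram of shape (batch, 33 frequency bins, 499 frames).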
class LogmelFilterBank(nn.Module):
def __init__(self,
sr=16000,
n_fft=64,
                 n_mels=26,  # 26 mel bands; halved to 13 by the max-pool in Vgg7
fmin=0.0,
freeze_parameters=True):
super().__init__()
fmax = sr//2
W2 = librosa.filters.mel(sr=sr,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax).T
self.register_buffer('melW', torch.Tensor(W2))
self.register_buffer('amin', torch.Tensor([1e-10]))
def forward(self, x):
        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)  # projects the freq axis onto mel bands; the number of frames is unchanged
x = torch.where(x > self.amin, x, self.amin) # not in place
x = 10 * torch.log10(x)
return x
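# LogmelFilterBank projects the 33 frequency bins onto 26 mel bands and applies
# 10 * log10, yielding a single-channel "image" of shape (batch, 1, frames, 26)
# for the VGG-style convolutions below.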
class Conv(nn.Module):
def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
self.norm = nn.BatchNorm2d(c_out)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return torch.relu_(x)
class Vgg7(nn.Module):
def __init__(self):
super().__init__()
self.l1 = Conv( 1, 13)
self.l2 = Conv(13, 13)
self.l3 = Conv(13, 13)
self.maxpool_A = nn.MaxPool2d(3,
stride=2,
padding=1)
self.l4 = Conv(13, 13)
self.l5 = Conv(13, 13)
self.l6 = Conv(13, 13)
self.l7 = Conv(13, 13)
self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)
self.spectrogram_extractor = Spectrogram()
self.logmel_extractor = LogmelFilterBank()
def forward(self, x, attention_mask=None):
x = _prenorm(x, attention_mask=attention_mask)
x = self.spectrogram_extractor(x)
x = self.logmel_extractor(x)
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
x = self.maxpool_A(x) # reshape here? so these conv will have large kernel
x = self.l4(x)
x = self.l5(x)
x = self.l6(x)
x = self.l7(x)
        x = self.lin(x) * self.sof(x).softmax(2)  # attention-weighted features, shape [bs, ch, time-frames, mel]
x = x.sum(2)
x = torch.cat([x,
                       torch.bmm(x, x.transpose(1, 2))], 2)  # channel-similarity (Gram) matrix over the mel dims
return x.reshape(-1, 338)
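# Vgg7 pools the log-mel image down to 13 channels x 13 mel bands, collapses time
# with a softmax-attention-weighted sum, and concatenates the resulting 13 x 13
# summary with its 13 x 13 Gram matrix: 2 * 13 * 13 = 338 values per utterance.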
class Wav2SmallConfig(PretrainedConfig):
model_type = "wav2vec2"
def __init__(self,
**kwargs):
super().__init__(**kwargs)
self.half_mel = 13
self.n_fft = 64
self.n_time = 64
self.hidden = 2 * self.half_mel * self.half_mel
self.hop = self.n_time // 2
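# hidden = 2 * half_mel * half_mel = 338, matching the flattened Vgg7 output above.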
class Wav2Small(Wav2Vec2PreTrainedModel):
    config_class = Wav2SmallConfig  # let from_pretrained build the Wav2SmallConfig defined above
def __init__(self,
config):
super().__init__(config)
self.vgg7 = Vgg7()
self.adv = nn.Linear(config.hidden, 3) # 0=arousal, 1=dominance, 2=valence
def forward(self, x, attention_mask=None):
x = self.vgg7(x, attention_mask=attention_mask)
return self.adv(x)
model = Wav2Small.from_pretrained(
'audeering/wav2small').to(device).eval()
with torch.no_grad():
logits = model(signal.to(device))
print(f'\nArousal={logits[:, 0]}\n',
f'Dominance={logits[:, 1]}\n',
f'Valence={logits[:, 2]}\n')
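For clips of different lengths, batched inference is possible by zero-padding and passing an attention_mask (1 for real samples, 0 for padding), so that the pre-normalisation statistics ignore the padding. A minimal sketch; the file names below are placeholders:
files = ['a.wav', 'b.wav']  # placeholder paths
wavs = [librosa.load(f, sr=16000)[0] for f in files]
max_len = max(len(w) for w in wavs)
batch = torch.zeros(len(wavs), max_len)
mask = torch.zeros(len(wavs), max_len)
for i, w in enumerate(wavs):
    batch[i, :len(w)] = torch.from_numpy(w)
    mask[i, :len(w)] = 1
with torch.no_grad():
    adv = model(batch.to(device), attention_mask=mask.to(device))  # shape (len(files), 3)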