Wav2Small2.0 - Arousal / Dominance / Valence

Please note that this model is for research purposes only. A commercial license can be acquired from audEERING. The model expects a raw 16 kHz audio signal as input and outputs arousal, dominance, and valence, each nominally in the range [0, 1] (the final layer is linear, so the outputs are not clamped). The model is created following the Wav2Small paper and has a total of 17K params.

How to use

import torch
import numpy as np
import librosa
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from torch import nn




# Load 'test.wav' at 16 kHz; librosa.load returns a tuple whose first element
# is the audio array. unsqueeze(0) prepends a batch axis of size 1.
signal = torch.from_numpy(
    librosa.load('test.wav', sr=16000)[0]
).unsqueeze(0)
device = 'cpu'



def _prenorm(x, attention_mask=None):
    '''wav2vec2'''
    if attention_mask is not None:
        N = attention_mask.sum(1, keepdim=True)  # here attn msk is unprocessed just the original input
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N

    else:
        x -= x.mean(1, keepdim=True)  # mean is an onnx operator reducemean saves some ops compared to casting integer N to float and the div
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)




class Spectrogram(nn.Module):
    """Power spectrogram computed as a strided 1-D convolution with a windowed DFT basis.

    The basis is built from ``n_time``-sample frames and truncated to the
    first ``n_fft // 2 + 1`` frequency rows, so ``n_fft`` and ``n_time`` may
    differ (asymmetric, non-square DFT).

    Args:
        n_fft: number of DFT rows generated before truncation; the module
            emits ``n_fft // 2 + 1`` frequency bins.
        n_time: frame length in samples, i.e. the Hann window length and the
            convolution kernel size.
        hop_length: stride between consecutive frames, in samples.
        freeze_parameters: if True, the DFT kernels get ``requires_grad=False``.
    """

    def __init__(self,
                 n_fft=64,
                 n_time=64,
                 hop_length=32,
                 freeze_parameters=True):
        super().__init__()

        # Periodic Hann window of length n_time. It already has length n_time,
        # so no centre-padding is needed.
        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)

        out_channels = n_fft // 2 + 1

        # Entry (k, t) of the basis is exp(-2j*pi*k*t / n_time) * window[t].
        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
        omega = np.exp(-2 * np.pi * 1j / n_time)
        dft_matrix = np.power(omega, x * y)  # (n_fft, n_time)
        dft_matrix = dft_matrix * fft_window[None, :]
        dft_matrix = dft_matrix[:out_channels, None, :]  # (out_channels, 1, n_time)

        # ---- Asymmetric, non-square DFT
        # The kernel size must equal the basis length (n_time), so that the
        # declared conv shape matches the weights copied in below.
        self.conv_real = nn.Conv1d(1, out_channels, n_time, stride=hop_length,
                                   padding=0, bias=False)
        self.conv_imag = nn.Conv1d(1, out_channels, n_time, stride=hop_length,
                                   padding=0, bias=False)
        with torch.no_grad():
            # copy_ keeps the existing Parameter objects and casts to their dtype.
            self.conv_real.weight.copy_(torch.tensor(np.real(dft_matrix)))
            self.conv_imag.weight.copy_(torch.tensor(np.imag(dft_matrix)))
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """Return ``real**2 + imag**2`` with shape (batch, freq, time-frames).

        ``input`` is presumably (batch, samples); a channel axis is inserted
        before the convolutions.
        """
        x = input[:, None, :]

        real = self.conv_real(x)
        imag = self.conv_imag(x)
        return real ** 2 + imag ** 2  # bs, freq, time-frames


class LogmelFilterBank(nn.Module):
    """Project a power spectrogram onto a mel filter bank and convert it to dB.

    Args:
        sr: sample rate in Hz; the filter bank spans ``fmin`` .. ``sr // 2``.
        n_fft: FFT size passed to ``librosa.filters.mel``.
        n_mels: number of mel bands produced.
        fmin: lowest filter-bank frequency in Hz.
        freeze_parameters: accepted but unused (the module holds only buffers).
    """

    def __init__(self, sr=16000, n_fft=64, n_mels=26, fmin=0.0,
                 freeze_parameters=True):
        super().__init__()

        mel_basis = librosa.filters.mel(
            sr=sr,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=sr // 2,
        )
        # Transposed so the basis can right-multiply the spectrogram frames.
        self.register_buffer('melW', torch.Tensor(mel_basis.T))
        # Floor applied before log10 so that zero energy does not produce -inf.
        self.register_buffer('amin', torch.Tensor([1e-10]))

    def forward(self, x):
        # x is presumably (batch, freq, frames), as emitted by Spectrogram.
        # Add a channel axis, then swap to (batch, 1, frames, freq).
        frames = x.unsqueeze(1).transpose(2, 3)
        mel = frames @ self.melW  # changes the last axis to mel bands; frame count unchanged
        mel = torch.where(mel > self.amin, mel, self.amin)  # out-of-place floor
        return 10 * torch.log10(mel)


class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block (3x3, stride 1, padding 1 by default).

    The convolution's bias is disabled; the BatchNorm2d that follows carries
    its own learnable shift.
    """

    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            c_in,
            c_out,
            k,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(c_out)

    def forward(self, x):
        # relu_ runs in place on the BatchNorm output, a freshly allocated
        # tensor, so the caller's input is never modified.
        return torch.relu_(self.norm(self.conv(x)))





class Vgg7(nn.Module):
    """Wav2Small feature extractor: log-mel front end, seven conv blocks, attention pooling.

    Pipeline: waveform normalisation (``_prenorm``) -> power spectrogram ->
    log-mel (dB) -> three 13-channel conv blocks -> 3x3 stride-2 max-pool ->
    four more conv blocks -> softmax-weighted pooling over time frames ->
    concatenation with the channel Gram matrix -> flatten to (batch, features).

    With the default front end (26 mel bands, reduced to 13 by the max-pool),
    each utterance yields 13 * (13 + 13) = 338 features.
    """

    def __init__(self):
        super().__init__()

        self.l1 = Conv( 1, 13)
        self.l2 = Conv(13, 13)
        self.l3 = Conv(13, 13)
        self.maxpool_A = nn.MaxPool2d(3,
                                      stride=2,
                                      padding=1)
        self.l4 = Conv(13, 13)
        self.l5 = Conv(13, 13)
        self.l6 = Conv(13, 13)
        self.l7 = Conv(13, 13)
        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # pooled values
        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # pooling logits
        self.spectrogram_extractor = Spectrogram()
        self.logmel_extractor = LogmelFilterBank()

    def forward(self, x, attention_mask=None):
        """Map a batch of waveforms to a flat (batch, features) tensor.

        ``attention_mask`` only affects the waveform normalisation. Padded
        samples still pass through the spectrogram and the pooling.
        """
        x = _prenorm(x, attention_mask=attention_mask)
        x = self.spectrogram_extractor(x)  # [bs, freq, time-frames]
        x = self.logmel_extractor(x)       # [bs, 1, time-frames, mel]
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.maxpool_A(x)  # stride 2 on both time and mel axes
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        # Attention pooling: softmax over time frames (dim 2), separately for
        # each channel and mel bin.
        x = self.lin(x) * self.sof(x).softmax(2)   # [bs, ch, time-frames, mel]
        x = x.sum(2)                               # [bs, ch, mel]
        # Append the channel Gram matrix: unnormalised inner products over the
        # mel axis, not cosines.
        x = torch.cat([x,
                       torch.bmm(x, x.transpose(1, 2))], 2)  # [bs, ch, mel + ch]
        # flatten(1) keeps the batch axis explicit. A front-end change then
        # fails loudly in the next Linear instead of a hard-coded
        # reshape(-1, 338) silently mixing batch rows.
        return x.flatten(1)


class Wav2SmallConfig(PretrainedConfig):
    """Hugging Face configuration for Wav2Small.

    The architecture values below are constants. They are assigned after the
    base-class constructor runs, so same-named entries passed in ``kwargs``
    are overwritten.
    """
    model_type = "wav2vec2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        half_mel = 13
        n_time = 64
        self.half_mel = half_mel
        self.n_fft = 64
        self.n_time = n_time
        # Width of the feature vector fed to the output Linear layer.
        self.hidden = 2 * half_mel ** 2
        self.hop = n_time // 2


class Wav2Small(Wav2Vec2PreTrainedModel):
    """Vgg7 feature extractor followed by a linear head with three outputs.

    Output columns: 0=arousal, 1=dominance, 2=valence. The head is a plain
    ``nn.Linear``, so the outputs are not clamped to any range.
    """

    def __init__(self, config):
        super().__init__(config)
        self.vgg7 = Vgg7()
        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence

    def forward(self, x, attention_mask=None):
        features = self.vgg7(x, attention_mask=attention_mask)
        return self.adv(features)


# Load the pretrained 'audeering/wav2small' weights and switch to eval mode,
# so that BatchNorm uses its stored running statistics.
model = Wav2Small.from_pretrained(
    'audeering/wav2small').to(device).eval()
with torch.no_grad():
    # Output columns: 0=arousal, 1=dominance, 2=valence.
    logits = model(signal.to(device))

# One concatenated string rather than three print() arguments. The default
# sep=' ' would otherwise put a stray space at the start of the Dominance and
# Valence lines.
print(f'\nArousal={logits[:, 0]}\n'
      f'Dominance={logits[:, 1]}\n'
      f'Valence={logits[:, 2]}\n')
Downloads last month
115
Safetensors
Model size
16.1k params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Space using audeering/wav2small 1