File size: 3,619 Bytes
1bd8d95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab6416
c6e9a13
c8fd7dc
9e7614c
 
1bd8d95
 
 
9e7614c
9161434
 
9e7614c
 
 
 
9161434
 
1bd8d95
da56fb8
 
1bd8d95
8d889ff
1bd8d95
 
 
 
9e7614c
1bd8d95
6a3daac
da56fb8
f7e0f31
 
6a3daac
9f61606
 
1bd8d95
6a3daac
 
 
8d889ff
f7e0f31
6a3daac
e2d1e1b
 
f7e0f31
 
e2d1e1b
 
 
f7e0f31
da56fb8
e2d1e1b
 
 
 
0b8d088
f7e0f31
e2d1e1b
f7e0f31
e2d1e1b
f7e0f31
0b8d088
e2d1e1b
f7e0f31
e2d1e1b
f7e0f31
6a3daac
f7e0f31
8d889ff
6a3daac
f7e0f31
9f61606
 
e2d1e1b
f7e0f31
9f61606
e2d1e1b
9f61606
f7e0f31
 
 
9f61606
8d889ff
da56fb8
8d889ff
 
 
 
 
 
a3b1af9
da56fb8
a3b1af9
f7e0f31
e2d1e1b
 
9f61606
f7e0f31
e2d1e1b
a3b1af9
f7e0f31
e2d1e1b
8d889ff
da56fb8
f7e0f31
 
 
 
6a3daac
9f61606
 
 
1bd8d95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
---
license: cc-by-nc-sa-4.0
language:
- en
pipeline_tag: audio-classification
tags:
- wavlm
- wav2vec2
- msp-podcast
- emotion-recognition
- speech
- valence
- arousal
- dominance
- speech-emotion-recognition
- dkounadis
---

# Arousal - Dominance - Valence

Dimensional Speech Emotion Recognition model that averages the predictions of [WavLM](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) and [wav2vec2.0](https://github.com/audeering/w2v2-how-to).
Achieves `0.6760566` valence CCC on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast). Used as the teacher for [Wav2Small](https://arxiv.org/abs/2408.13920).



**[PapersWithCode](https://paperswithcode.com/dataset/msp-podcast) / [arXiv](https://arxiv.org/abs/2408.13920)**

```
Wav2Small: Distilling Wav2Vec2 to 72K parameters for low-resource
speech emotion recognition.
D. Kounadis-Bastian, O. Schrüfer, A. Derington, H. Wierstorf,
F. Eyben, F. Burkhardt, B.W. Schuller. 2024, arXiv preprint
```

<table style="width:500px">
  <tr><th colspan=6 align="center" >CCC MSP Podcast v1.7</th></tr>
  <tr><th colspan=3 align="center">Test 1</th><th colspan=3 align="center">Test 2</th></tr>
  <tr>   <td>Val</td> <td>Dom</td> <td>Aro</td> <td>Val</td> <td>Dom</td> <td>Aro</td> </tr>
  <tr>  <td> 0.6760566 </td> <td>0.6840044</td> <td>0.7620181</td> <td>0.4229267</td> <td>0.4684658</td> <td>0.4857733</td> </tr>
</table>
 


# HowTo
```python
import librosa
import torch
import types
import torch.nn as nn
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                  Wav2Vec2PreTrainedModel)


# Device on which both models and the input waveform are placed.
device = 'cpu'

# Load 'test.wav' resampled to 16 kHz (librosa.load returns (samples, sr);
# only the samples are kept) and add a leading axis of size 1.
signal = torch.from_numpy(
    librosa.load('test.wav', sr=16000)[0]).unsqueeze(0)

class ADV(nn.Module):
    """Two-layer head: linear -> tanh -> linear projection to `num_labels`."""

    def __init__(self, config):
        """Create both linear layers.

        `config` must expose `hidden_size` and `num_labels`.
        """
        super().__init__()
        width = config.hidden_size
        # nn.Module registers the layers under these attribute names, so
        # they are kept exactly as they are.
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, x):
        """Return `out_proj(tanh(dense(x)))`."""
        hidden = self.dense(x).tanh()
        return self.out_proj(hidden)


class Dawn(Wav2Vec2PreTrainedModel):
    r"""Wav2Vec2 encoder followed by an `ADV` head on mean-pooled states.

    Reference: https://arxiv.org/abs/2203.07378
    """

    def __init__(self, config):
        """Build the Wav2Vec2 encoder and the `ADV` head from `config`."""
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        """Normalise the waveform, encode it and predict from the pooled states.

        x: raw waveform; dim 1 is treated as the sample axis
           (presumably (batch, samples) -- confirm against the caller).
        Returns the `ADV` head output for the hidden states averaged over
        dim 1.
        """
        # Out-of-place on purpose: `x -= ...` would overwrite the caller's
        # tensor, so any later use of the same tensor would silently see the
        # mean-centred signal (and an in-place op fails on a leaf tensor
        # that requires grad).
        x = x - x.mean(1, keepdim=True)
        # Mean of squares of the centred signal; the epsilon guards the
        # division for an all-zero (silent) input.
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        return self.classifier(x.last_hidden_state.mean(1))


def _forward(self, x):
    '''Normalise, encode, pool with attention-weighted statistics, regress.

    x: (batch, audio-samples-16KHz)

    Reads `self.config.mean` / `self.config.std`, `self.ssl_model`,
    `self.pool_model.sap_linear` / `self.pool_model.attention` and
    `self.ser_model`; returns the output of `self.ser_model`.
    '''
    cfg = self.config
    # NOTE: the mean is *added* here (original comment: "sgn").
    normed = (x + cfg.mean) / cfg.std
    frames = self.ssl_model(normed, attention_mask=None).last_hidden_state

    # Attention weights, softmax-normalised along dim 1.
    pool = self.pool_model
    scores = torch.matmul(torch.tanh(pool.sap_linear(frames)), pool.attention)
    weights = torch.softmax(scores, dim=1)

    # Weighted mean and weighted standard deviation along dim 1; the clamp
    # keeps the argument of the square root strictly positive.
    mean = torch.sum(frames * weights, dim=1)
    second_moment = torch.sum(frames * frames * weights, dim=1)
    std = torch.sqrt(torch.clamp(second_moment - mean * mean, min=1e-7))

    return self.ser_model(torch.cat([mean, std], 1))


# WavLM
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository -- only keep it for repositories you trust.
base = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
    trust_remote_code=True)
base = base.to(device)
base = base.eval()
# Swap the model's forward for the module-level `_forward`, bound to this
# instance.
base.forward = types.MethodType(_forward, base)

# Wav2Vec2
dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim')
dawn = dawn.to(device)
dawn = dawn.eval()


def wav2small(x):
    """Return the equal-weight average of the `dawn` and `base` outputs for `x`.

    `dawn` is evaluated before `base`, on the same tensor object.
    """
    dawn_out = dawn(x)
    base_out = base(x)
    return .5 * dawn_out + .5 * base_out

# Inference only: disable autograd so no graph is built and the printed
# values are plain numbers instead of `tensor(..., grad_fn=...)`.
with torch.no_grad():
    pred = wav2small(signal.to(device))

# One string with single spaces between fields (a stray comma previously
# split this into two print arguments, producing a double space).
print(f'Arousal={pred[0, 0].item()} '
      f'Dominance={pred[0, 1].item()} '
      f'Valence={pred[0, 2].item()}')
```