dkounadis committed on
Commit
0b8d088
1 Parent(s): a3b1af9
Files changed (1) hide show
  1. README.md +16 -19
README.md CHANGED
@@ -50,39 +50,35 @@ import torch
50
  import types
51
  import torch.nn as nn
52
 
53
- # speech signal 16 KHz
54
- signal = torch.rand((1, 16000))
55
  device = 'cpu'
56
 
57
- class RegressionHead(nn.Module):
58
- def __init__(self):
59
 
 
60
  super().__init__()
61
  self.dense = nn.Linear(1024, 1024)
62
  self.out_proj = nn.Linear(1024, 3)
63
 
64
  def forward(self, x):
65
- x = self.dense(x)
66
- return self.out_proj(x.tanh())
67
 
68
  class Dawn(Wav2Vec2PreTrainedModel):
69
- def __init__(self, config):
70
 
 
71
  super().__init__(config)
72
-
73
  self.wav2vec2 = Wav2Vec2Model(config)
74
- self.classifier = RegressionHead()
75
 
76
  def forward(self, x):
77
- '''x: (batch, audio-samples-16KHz)'''
78
  x = x - x.mean(1, keepdim=True)
79
  variance = (x * x).mean(1, keepdim=True) + 1e-7
80
  x = self.wav2vec2(x / variance.sqrt())[0]
81
  return self.classifier(x.mean(1)).clip(0, 1)
82
 
83
- def _forward(self, x):
84
- '''x: (batch, audio-samples-16KHz)'''
85
- x = (x + self.config.mean) / self.config.std # plus
86
  x = self.ssl_model(x, attention_mask=None).last_hidden_state
87
  # pool
88
  h = self.pool_model.sap_linear(x).tanh()
@@ -100,23 +96,24 @@ def _forward(self, x):
100
  # WavLM
101
 
102
  base = AutoModelForAudioClassification.from_pretrained(
103
- '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
104
- trust_remote_code=True).to(device).eval()
105
- base.forward = types.MethodType(_forward, base)
106
 
107
  # Wav2Vec2.0
108
 
109
  dawn = Dawn.from_pretrained(
110
- 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
111
  ).to(device).eval()
112
 
113
 
114
  def wav2small(x):
 
115
  return .5 * dawn(x) + .5 * base(x)
116
 
117
 
118
  with torch.no_grad():
119
  pred = wav2small(signal.to(device))
120
- print(f'\nArousal = {pred[0, 0]} Dominance= {pred[0, 1]}',
121
- f' Valence = {pred[0, 2]}')
122
  ```
 
50
  import types
51
  import torch.nn as nn
52
 
53
+ signal = torch.rand((1, 16000)) # audio signal 16 KHz
 
54
  device = 'cpu'
55
 
56
+ class ADV(nn.Module):
 
57
 
58
+ def __init__(self):
59
  super().__init__()
60
  self.dense = nn.Linear(1024, 1024)
61
  self.out_proj = nn.Linear(1024, 3)
62
 
63
  def forward(self, x):
64
+ x = self.dense(x).tanh()
65
+ return self.out_proj(x)
66
 
67
  class Dawn(Wav2Vec2PreTrainedModel):
 
68
 
69
+ def __init__(self, config):
70
  super().__init__(config)
 
71
  self.wav2vec2 = Wav2Vec2Model(config)
72
+ self.classifier = ADV()
73
 
74
  def forward(self, x):
 
75
  x = x - x.mean(1, keepdim=True)
76
  variance = (x * x).mean(1, keepdim=True) + 1e-7
77
  x = self.wav2vec2(x / variance.sqrt())[0]
78
  return self.classifier(x.mean(1)).clip(0, 1)
79
 
80
+ def _fast(self, x):
81
+ x = (x + self.config.mean) / self.config.std # sign
 
82
  x = self.ssl_model(x, attention_mask=None).last_hidden_state
83
  # pool
84
  h = self.pool_model.sap_linear(x).tanh()
 
96
  # WavLM
97
 
98
  base = AutoModelForAudioClassification.from_pretrained(
99
+ '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
100
+ trust_remote_code=True).to(device).eval()
101
+ base.forward = types.MethodType(_fast, base)
102
 
103
  # Wav2Vec2.0
104
 
105
  dawn = Dawn.from_pretrained(
106
+ 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
107
  ).to(device).eval()
108
 
109
 
110
  def wav2small(x):
111
+ '''x: (batch, audio-samples-16KHz)'''
112
  return .5 * dawn(x) + .5 * base(x)
113
 
114
 
115
  with torch.no_grad():
116
  pred = wav2small(signal.to(device))
117
+ print(f'arousal={pred[0, 0]} dominance={pred[0, 1]}',
118
+ f'valence={pred[0, 2]}')
119
  ```