student wav2small 17K params

app.py CHANGED
@@ -17,21 +17,6 @@ plt.style.use('seaborn-v0_8-whitegrid')



-def _prenorm(x, attention_mask=None):
-    '''mean/var'''
-    if attention_mask is not None:
-        N = attention_mask.sum(1, keepdim=True)  # 0=ignored 1=valid
-        x -= x.sum(1, keepdim=True) / N
-        var = (x * x).sum(1, keepdim=True) / N
-
-    else:
-        x -= x.mean(1, keepdim=True)  # mean is an onnx operator reducemean saves some ops compared to casting integer N to float and the div
-        var = (x * x).mean(1, keepdim=True)
-    return x / torch.sqrt(var + 1e-7)
-
-
 class ADV(nn.Module):

     def __init__(self, config):
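
For reference while reading this hunk: _prenorm standardizes each utterance to zero mean and unit variance (it is deleted here and re-added, together with the new model code, in the next hunk). A minimal sketch of its unmasked branch, assuming the function as defined above is in scope:

    import torch

    x = 3.0 + 2.0 * torch.randn(1, 16000)   # offset, scaled fake waveform
    y = _prenorm(x)                          # unmasked branch: per-row stats
    print(y.mean().item())                   # ~0.0
    print((y * y).mean().item())             # ~1.0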

@@ -96,16 +81,275 @@ dawn = Dawn.from_pretrained(
 ).to(device).eval()


-def wav2small(x):
-    return .5 * dawn(x) + .5 * base(x)
-
-
-fig_error, ax = plt.subplots(figsize=(8, 6))
-
-# Set the text to display
-error_message = "Error: No .wav or Mic. audio provided."
-
+# Wav2Small
+
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+import librosa
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model
+from torch import nn
+from transformers import PretrainedConfig
+
+
+def _prenorm(x, attention_mask=None):
+    '''mean/var'''
+    if attention_mask is not None:
+        N = attention_mask.sum(1, keepdim=True)  # here attn msk is unprocessed just the original input
+        x -= x.sum(1, keepdim=True) / N
+        var = (x * x).sum(1, keepdim=True) / N
+
+    else:
+        x -= x.mean(1, keepdim=True)  # mean is an onnx operator reducemean saves some ops compared to casting integer N to float and the div
+        var = (x * x).mean(1, keepdim=True)
+    return x / torch.sqrt(var + 1e-7)
+
+
+class Spectrogram(nn.Module):
+    def __init__(self,
+                 n_fft=64,    # num cols of DFT
+                 n_time=64,   # num rows of DFT matrix
+                 hop_length=32,
+                 freeze_parameters=True):
+
+        super().__init__()
+
+        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
+        fft_window = librosa.util.pad_center(fft_window, size=n_time)
+
+        out_channels = n_fft // 2 + 1
+
+        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
+        omega = np.exp(-2 * np.pi * 1j / n_time)
+        dft_matrix = np.power(omega, x * y)  # (n_fft, n_time)
+        dft_matrix = dft_matrix * fft_window[None, :]
+        dft_matrix = dft_matrix[0 : out_channels, :]
+        dft_matrix = dft_matrix[:, None, :]
+
+        # ---- Asymmetric DFT, non-square
+        self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_real.weight.data = torch.tensor(np.real(dft_matrix), dtype=self.conv_real.weight.dtype).to(self.conv_real.weight.device)
+        self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix), dtype=self.conv_imag.weight.dtype).to(self.conv_imag.weight.device)
+        if freeze_parameters:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, input):
+        x = input[:, None, :]
+
+        real = self.conv_real(x)
+        imag = self.conv_imag(x)
+        return real ** 2 + imag ** 2  # bs, mel, time-frames
+
+
+class LogmelFilterBank(nn.Module):
+    def __init__(self,
+                 sr=16000,
+                 n_fft=64,
+                 n_mels=26,  # maxpool
+                 fmin=0.0,
+                 freeze_parameters=True):
+
+        super().__init__()
+
+        fmax = sr // 2
+
+        W2 = librosa.filters.mel(sr=sr,
+                                 n_fft=n_fft,
+                                 n_mels=n_mels,
+                                 fmin=fmin,
+                                 fmax=fmax).T
+
+        self.register_buffer('melW', torch.Tensor(W2))
+        self.register_buffer('amin', torch.Tensor([1e-10]))
+
+    def forward(self, x):
+        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)  # changes melf not num frames
+        x = torch.where(x > self.amin, x, self.amin)  # not in place
+        x = 10 * torch.log10(x)
+        return x
+
+
+def length_after_conv_layer(_length, k=None, pad=None, stride=None):
+    return torch.floor((_length + 2 * pad - k) / stride + 1)
+
+
+class Conv(nn.Module):
+
+    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
+
+        super().__init__()
+
+        self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
+        self.norm = nn.BatchNorm2d(c_out)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return torch.relu_(x)
+
+
+class Vgg7(nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.l1 = Conv( 1, 13)
+        self.l2 = Conv(13, 13)
+        self.l3 = Conv(13, 13)
+        self.maxpool_A = nn.MaxPool2d(3,
+                                      stride=2,
+                                      padding=1)
+        self.l4 = Conv(13, 13)
+        self.l5 = Conv(13, 13)
+        self.l6 = Conv(13, 13)
+        self.l7 = Conv(13, 13)
+        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
+        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # pool time - reshape mel into channels after pooling
+        self.spectrogram_extractor = Spectrogram()
+        self.logmel_extractor = LogmelFilterBank()
+
+    def final_length(self, L):
+        conv_kernel = [64, 3]  # [nfft, maxpool]
+        conv_stride = [32, 2]  # [hop_len, maxpool_stride]  consider only layers of stride > 1
+        conv_pad = [0, 1]      # [pad_stft, pad_maxpool]
+        for k, stride, pad in zip(conv_kernel, conv_stride, conv_pad):
+            L = length_after_conv_layer(L, k=k, stride=stride, pad=pad)
+        return L
+
+    def final_attention_mask(self, feature_vector_length, attention_mask=None):
+        non_padded_lengths = attention_mask.sum(1)
+        out_lengths = self.final_length(non_padded_lengths)  # how can non_padded_lengths get exact 0 here DOES IT MEAN ATTNMASK WAS NOT FILLED?
+        out_lengths = out_lengths.to(torch.long)
+        bs, _ = attention_mask.shape
+        attention_mask = torch.ones((bs, feature_vector_length),
+                                    dtype=attention_mask.dtype,
+                                    device=attention_mask.device)
+        for b, _len in enumerate(out_lengths):
+            attention_mask[b, _len:] = 0
+        return attention_mask
+
+    def forward(self, x, attention_mask=None):
+        x = _prenorm(x,
+                     attention_mask=attention_mask)
+        x = self.spectrogram_extractor(x)
+        x = self.logmel_extractor(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        x = self.maxpool_A(x)  # reshape here? so these conv will have large kernel
+        x = self.l4(x)
+        x = self.l5(x)
+        x = self.l6(x)
+        x = self.l7(x)
+        if attention_mask is not None:
+            bs, _, t, _ = x.shape
+            a = self.final_attention_mask(feature_vector_length=t,
+                                          attention_mask=attention_mask)[:, None, :, None]
+            # print(a.shape, x.shape, '\n\n\n\n')
+            x = torch.masked_fill(x, a < 1, 0)
+            # mask also affects lin !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+            x = self.lin(x) * (self.sof(x) - 10000. * torch.logical_not(a)).softmax(2)
+        else:
+            x = self.lin(x) * self.sof(x).softmax(2)
+
+        x = x.sum(2)  # bs, ch, time-frames, HALF_MEL -> bs, ch, HALF_MEL
+        # --
+        xT = x.transpose(1, 2)
+        x = torch.cat([x,
+                       torch.bmm(x, xT),  # corr (chxmel) x (melxCH)
+                       # torch.bmm(x, x),    # corr ch * ch
+                       # torch.bmm(xT, xT),  # corr mel * mel
+                       ], 2)
+        # --
+        return x.reshape(-1, 338)
+
+
+class Wav2SmallConfig(PretrainedConfig):
+    model_type = "wav2vec2"
+
+    def __init__(self,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.half_mel = 13
+        self.n_fft = 64
+        self.n_time = 64
+        self.hidden = 2 * self.half_mel * self.half_mel
+        self.hop = self.n_time // 2
+
+
+class Wav2Small(Wav2Vec2PreTrainedModel):
+
+    def __init__(self,
+                 config):
+        super().__init__(config)
+        self.vgg7 = Vgg7()
+        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence
+
+    def forward(self, x, attention_mask=None):
+        x = self.vgg7(x, attention_mask=attention_mask)
+        return self.adv(x)
+
+
+def _ccc(x, y):
+    '''if len(x) = len(y) = 1 we have 0/0; as a & b can both be negative we should add 1e-7 to the denominator while protecting the sign of the denominator:
+    find the sign of the denominator and add 1e-7 if sgn >= 0 or -1e-7 if sgn < 0'''
+
+    mean_y = y.mean()
+    mean_x = x.mean()
+    a = x - mean_x
+    b = y - mean_y
+    L = (mean_x - mean_y).abs() * .1 * x.shape[0]
+    # print(L / ((mean_x - mean_y) ** 2 * x.shape[0]))
+    numerator = torch.dot(a, b)  # L term: if a, b are both scalars it disallows a 0 numerator [OFFICIAL CCC HAS L ONLY IN D]
+    denominator = torch.dot(a, a) + torch.dot(b, b) + L  # if a, b are equal scalars then the dots are all zero and ccc=1
+    denominator = torch.where(denominator.sign() < 0,
+                              denominator - 1e-7,
+                              denominator + 1e-7)
+    ccc = numerator / denominator
+
+    return -ccc  # + F.l1_loss(a, b)
+
+
+wav2small = Wav2Small.from_pretrained('audeering/wav2small').to(device).eval()
+
+
+# Error figure for the first plot
+fig_error, ax = plt.subplots(figsize=(8, 6))
+error_message = "Error: No .wav or Mic. audio provided."
 ax.text(0.5, 0.5, error_message,
         ha='center',
         va='center',
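
The Spectrogram module added above implements the DFT as a pair of frozen Conv1d layers, so the whole front end stays exportable as plain ONNX ops. A quick numerical check that its first frame matches a windowed FFT power spectrum (a sketch, assuming the classes added in this commit are in scope):

    import numpy as np
    import torch
    import librosa

    spec = Spectrogram()                       # n_fft=64, hop_length=32, frozen
    wave = torch.randn(1, 640)
    power = spec(wave)                         # (1, 33, n_frames)

    # Reference: power spectrum of the first hop-aligned frame via rfft.
    win = librosa.filters.get_window('hann', 64, fftbins=True)
    ref = np.abs(np.fft.rfft(wave[0, :64].numpy() * win)) ** 2
    print(np.allclose(power[0, :, 0].numpy(), ref, atol=1e-3))   # True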

@@ -113,125 +357,164 @@ ax.text(0.5, 0.5, error_message,
         color='gray',
         fontweight='bold',
         transform=ax.transAxes)
-
-# Hide the axis ticks and labels for a cleaner look
 ax.set_xticks([])
 ax.set_yticks([])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
-
-# Optional: Add a border around the text to make it stand out more
 ax.set_frame_on(True)
 ax.spines['top'].set_visible(False)
 ax.spines['right'].set_visible(False)
 ax.spines['bottom'].set_visible(False)
 ax.spines['left'].set_visible(False)

-
-
-
-
 def process_audio(audio_filepath):
     if audio_filepath is None:
-
+
+        return fig_error, fig_error

-
-    waveform, sample_rate = librosa.load(audio_filepath)
+    waveform, sample_rate = librosa.load(audio_filepath, sr=None)

-    #
-
-    # Resample audio to 16kHz if necessary
+    # Resample audio to 16kHz if the sample rate is different
     if sample_rate != 16000:
         resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
-
-
+    else:
+        resampled_waveform_np = waveform[None, :]
+
+    x = torch.from_numpy(resampled_waveform_np).to(torch.float)
+
     with torch.no_grad():
+
         logits_dawn = dawn(x).cpu().numpy()[0, :]
-        logits_wavlm = base(x).cpu().numpy()[0, :]

-
+        logits_wavlm = base(x).cpu().numpy()[0, :]
+
+        # 17K params
+        logits_wav2small = wav2small(x).cpu().numpy()[0, :]
+
+
+    # --- Plot 1: Wav2Vec2 (Dawn) vs Wav2Small (17K params) ---
+
+    fig, ax = plt.subplots(figsize=(10, 6))

-    # left_bars_data = np.array([0.75, 0.5, 0.9])
-    # right_bars_data = np.array([0.3, 0.8, 0.65])
     left_bars_data = logits_dawn.clip(0, 1)
     right_bars_data = logits_wav2small.clip(0, 1)

-
     bar_labels = ['\nArousal', '\nDominance', '\nValence']
     y_pos = np.arange(len(bar_labels))

-    # Define
-    # Using Greys for Dominance as requested
+    # Define colormaps for each category to ensure distinct colors
     category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]

-    # Define color shades for left and right for each category
     left_filled_colors = []
     right_filled_colors = []
     background_colors = []

+    # Assign specific shades for filled bars and background bars
     for i, cmap in enumerate(category_colormaps):
-
-
-        # Pick a slightly lighter shade for the right filled bar
-        right_filled_colors.append(cmap(0.64))  # 0.5
-        # Pick a very light shade for the transparent background bar
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
         background_colors.append(cmap(0.1))

-    #
-    fig, ax = plt.subplots(figsize=(10, 6))
-
-    # Plot the background bars with transparency
+    # Plot transparent background bars
     for i in range(len(bar_labels)):
-        # Left background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
-        # Right background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)

-    # Plot the filled bars for
+    # Plot the filled bars for actual data
     for i in range(len(bar_labels)):
-        # Left filled bar (opaque, darker shade of category color)
         ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
-        # Right filled bar (opaque, lighter shade of category color)
         ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)

-    # Add a central axis divider
+    # Add a central vertical axis divider
     ax.axvline(0, color='black', linewidth=0.8, linestyle='--')

-    # Set x-axis limits and y-axis ticks
+    # Set x-axis limits and y-axis ticks/labels
     ax.set_xlim(-1, 1)
     ax.set_yticks(y_pos)
     ax.set_yticklabels(bar_labels, fontsize=12)

-
+    # Custom formatter for x-axis to show absolute percentage values
     def abs_tick_formatter(x, pos):
         return f'{int(abs(x) * 100)}%'
     ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))

-    #
+    # Set plot title and x-axis label
     ax.set_title('', fontsize=16, pad=20)
-    ax.set_xlabel('
+    ax.set_xlabel('Wav2Vec2 (Dawn)          Wav2Small (17K param.)', fontsize=12)

-    # Remove
+    # Remove top, right, and left spines for a cleaner look
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
     ax.spines['left'].set_visible(False)

-    # Add annotations to the filled bars
+    # Add annotations (percentage values) to the filled bars
     for i in range(len(bar_labels)):
-        # Left annotation (uses left_filled_colors for text color)
         ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                 va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
-        # Right annotation (uses right_filled_colors for text color)
         ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                 va='center', ha='left', color=right_filled_colors[i], fontweight='bold')


-    return fig
+    # -- PLOT 2 : WavLM / Wav2Small Teacher
+
+    fig_2, ax_2 = plt.subplots(figsize=(10, 6))
+
+    left_bars_data = logits_wavlm.clip(0, 1)
+    right_bars_data = (.5 * logits_dawn + .5 * logits_wavlm).clip(0, 1)
+
+    bar_labels = ['\nArousal', '\nDominance', '\nValence']
+    y_pos = np.arange(len(bar_labels))
+
+    # Define colormaps for each category to ensure distinct colors
+    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
+
+    left_filled_colors = []
+    right_filled_colors = []
+    background_colors = []
+
+    # Assign specific shades for filled bars and background bars
+    for i, cmap in enumerate(category_colormaps):
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
+        background_colors.append(cmap(0.1))
+
+    # Plot transparent background bars
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
+        ax_2.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
+
+    # Plot the filled bars for actual data
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
+        ax_2.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
+
+    # Add a central vertical axis divider
+    ax_2.axvline(0, color='black', linewidth=0.8, linestyle='--')
+
+    # Set x-axis limits and y-axis ticks/labels
+    ax_2.set_xlim(-1, 1)
+    ax_2.set_yticks(y_pos)
+    ax_2.set_yticklabels(bar_labels, fontsize=12)
+
+    # Custom formatter for x-axis to show absolute percentage values
+    def abs_tick_formatter(x, pos):
+        return f'{int(abs(x) * 100)}%'
+    ax_2.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
+    ax_2.set_title('', fontsize=16, pad=20)
+    ax_2.set_xlabel('WavLM (Baseline)          Wav2Small Teacher (0.4B param.)', fontsize=12)
+    ax_2.spines['top'].set_visible(False)
+    ax_2.spines['right'].set_visible(False)
+    ax_2.spines['left'].set_visible(False)
+
+    # Add annotations (percentage values) to the filled bars
+    for i in range(len(bar_labels)):
+        ax_2.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
+                  va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
+        ax_2.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
+                  va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
+    return fig, fig_2


 iface = gr.Interface(
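
process_audio now feeds the models a (1, samples) float tensor whichever branch runs: the resample branch relies on audresample, and the already-16kHz branch adds the axis by hand. A standalone sketch of that input path, using one of the bundled example files and assuming audresample's usual (channels, samples) output shape:

    import librosa
    import torch
    import audresample

    waveform, sample_rate = librosa.load('female-46-neutral.wav', sr=None)
    if sample_rate != 16000:
        resampled = audresample.resample(waveform, sample_rate, 16000)
    else:
        resampled = waveform[None, :]
    x = torch.from_numpy(resampled).to(torch.float)
    print(x.shape)   # torch.Size([1, num_samples])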

@@ -242,25 +525,27 @@ iface = gr.Interface(
         label=''
     ),
     outputs=[
-        gr.Plot(label="
+        gr.Plot(label="Wav2Vec2 vs Wav2Small (17K params) Plot"),  # First plot output
+        gr.Plot(label="WavLM vs Wav2Small Teacher Plot"),  # Second plot output
     ],
     title='',
     description='',
-    flagging_mode="never",  #
+    flagging_mode="never",  # Disables flagging feature
     examples=[
         "female-46-neutral.wav",
         "female-20-happy.wav",
         "male-60-angry.wav",
         "male-27-sad.wav",
     ],
-    css="footer {visibility: hidden}"
+    css="footer {visibility: hidden}"  # Hides the Gradio footer
 )

+# Gradio Blocks for tabbed interface
 with gr.Blocks() as demo:
-
-    # https://discuss.huggingface.co/t/how-to-get-the-microphone-streaming-input-file-when-using-blocks/37204/3
+    # First tab for the existing Arousal/Dominance/Valence plots
     with gr.Tab(label="Arousal / Dominance / Valence"):
         iface.render()
+    # Second tab for CCC (Concordance Correlation Coefficient) information
     with gr.Tab(label="CCC"):
         gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
 <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
@@ -269,5 +554,6 @@ with gr.Blocks() as demo:
 </table>
 ''')

+# Launch the Gradio application
 if __name__ == "__main__":
     demo.launch(share=False)
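
Vgg7.final_length propagates a raw sample count through the two strided layers (the conv-DFT and maxpool_A) with the standard output-length formula L' = floor((L + 2*pad - k)/stride + 1). Worked by hand for one second of 16 kHz audio, assuming the classes from this commit are in scope:

    import torch

    L = torch.tensor([16000.])
    # conv-DFT: k=64, pad=0, stride=32 -> floor((16000 - 64)/32 + 1) = 499
    # maxpool:  k=3,  pad=1, stride=2  -> floor((499 + 2 - 3)/2 + 1)  = 250
    print(Vgg7().final_length(L))   # tensor([250.])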
293 |
+
|
294 |
+
|
295 |
+
class Wav2SmallConfig(PretrainedConfig):
|
296 |
+
model_type = "wav2vec2"
|
297 |
+
|
298 |
+
def __init__(self,
|
299 |
+
**kwargs):
|
300 |
+
super().__init__(**kwargs)
|
301 |
+
self.half_mel = 13
|
302 |
+
self.n_fft = 64
|
303 |
+
self.n_time = 64
|
304 |
+
self.hidden = 2 * self.half_mel * self.half_mel
|
305 |
+
self.hop = self.n_time // 2
|
306 |
+
|
307 |
+
|
308 |
+
class Wav2Small(Wav2Vec2PreTrainedModel):
|
309 |
+
|
310 |
+
def __init__(self,
|
311 |
+
config):
|
312 |
+
super().__init__(config)
|
313 |
+
self.vgg7 = Vgg7()
|
314 |
+
self.adv = nn.Linear(config.hidden, 3) # 0=arousal, 1=dominance, 2=valence
|
315 |
+
|
316 |
+
def forward(self, x, attention_mask=None):
|
317 |
+
x = self.vgg7(x, attention_mask=attention_mask)
|
318 |
+
return self.adv(x)
|
319 |
+
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
|
324 |
+
|
325 |
+
|
326 |
+
def _ccc(x, y):
|
327 |
+
'''if len(x) = len(y) = 1 we have 0/0 as a&b can both be negative we should add 1e-7 to denominator protecting sign of denominator
|
328 |
+
to find sign of denominator and add 1e-7 if sgn>=0 or -1e-7 if sgn<0'''
|
329 |
+
|
330 |
+
mean_y = y.mean()
|
331 |
+
mean_x = x.mean()
|
332 |
+
a = x - mean_x
|
333 |
+
b = y - mean_y
|
334 |
+
L = (mean_x - mean_y).abs() * .1 * x.shape[0]
|
335 |
+
#print(L / ((mean_x - mean_y) **2 * x.shape[0]))
|
336 |
+
numerator = torch.dot(a, b) # L term if both a,b scalars dissallows 0 numerator [OFFICIAL CCC HAS L ONLY IN D]
|
337 |
+
denominator = torch.dot(a, a) + torch.dot(b, b) + L # if both a,b are equalscalars then the dots are all zero and ccc=1
|
338 |
+
denominator = torch.where(denominator.sign() < 0,
|
339 |
+
denominator - 1e-7,
|
340 |
+
denominator + 1e-7)
|
341 |
+
ccc = numerator / denominator
|
342 |
+
|
343 |
+
return -ccc #+ F.l1_loss(a, b)
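
_ccc returns a negative concordance so it can be minimized directly. Note that it uses dot(a, b) rather than the textbook 2 * dot(a, b) in the numerator, so perfect agreement scores about -0.5 rather than -1 (a rescaling that leaves the argmin unchanged). A quick behavioral check, assuming _ccc from this commit is in scope:

    import torch

    t = torch.linspace(0, 1, 100)
    good = t + 0.01 * torch.randn(100)   # near-perfect predictions
    bad = torch.randn(100)               # uncorrelated predictions
    print(_ccc(good, t))                 # close to -0.5 (best case here)
    print(_ccc(bad, t))                  # near 0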
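
Finally, an end-to-end sketch of the new student: the 17K-parameter Wav2Small maps a raw 16 kHz waveform straight to arousal/dominance/valence logits, with no transformer in the loop (assumes the classes above are in scope and the audeering/wav2small checkpoint the Space loads is reachable):

    import torch

    model = Wav2Small.from_pretrained('audeering/wav2small').eval()
    wave = torch.randn(1, 16000)             # 1 second of fake 16 kHz audio
    with torch.no_grad():
        adv = model(wave)                    # (1, 3)
    arousal, dominance, valence = adv[0].tolist()
    print(f'A={arousal:.2f} D={dominance:.2f} V={valence:.2f}')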