Dionyssos committed
Commit 07ebc68 · Parent: e0f0baf

student wav2small 17K params

Files changed (1): app.py (+360 -74)

app.py
@@ -17,21 +17,6 @@ plt.style.use('seaborn-v0_8-whitegrid')
 
 
-def _prenorm(x, attention_mask=None):
-    '''mean/var normalisation'''
-    if attention_mask is not None:
-        N = attention_mask.sum(1, keepdim=True)  # 0 = ignored, 1 = valid
-        x -= x.sum(1, keepdim=True) / N
-        var = (x * x).sum(1, keepdim=True) / N
-    else:
-        x -= x.mean(1, keepdim=True)  # ReduceMean is a single ONNX op; cheaper than casting the integer N to float and dividing
-        var = (x * x).mean(1, keepdim=True)
-    return x / torch.sqrt(var + 1e-7)
-
-
 class ADV(nn.Module):
 
     def __init__(self, config):
@@ -96,16 +81,275 @@ dawn = Dawn.from_pretrained(
 ).to(device).eval()
 
 
-def wav2small(x):
-    return .5 * dawn(x) + .5 * base(x)
+# Wav2Small
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+import librosa
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model
+from torch import nn
+from transformers import PretrainedConfig
+
+
+def _prenorm(x, attention_mask=None):
+    '''mean/var normalisation'''
+    if attention_mask is not None:
+        N = attention_mask.sum(1, keepdim=True)  # the attention mask here is unprocessed, just the original input mask
+        x -= x.sum(1, keepdim=True) / N
+        var = (x * x).sum(1, keepdim=True) / N
+    else:
+        x -= x.mean(1, keepdim=True)  # ReduceMean is a single ONNX op; cheaper than casting the integer N to float and dividing
+        var = (x * x).mean(1, keepdim=True)
+    return x / torch.sqrt(var + 1e-7)
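
A quick sanity sketch for _prenorm (not part of the commit; shapes assumed): without a mask it should match plain per-row standardisation up to the 1e-7 epsilon.

    x = torch.randn(2, 64000)
    y = _prenorm(x.clone())  # _prenorm centers in place, hence the clone
    z = (x - x.mean(1, keepdim=True)) / torch.sqrt(x.var(1, unbiased=False, keepdim=True) + 1e-7)
    assert torch.allclose(y, z, atol=1e-5)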
+
+
+class Spectrogram(nn.Module):
+    def __init__(self,
+                 n_fft=64,      # number of columns of the DFT matrix
+                 n_time=64,     # number of rows of the DFT matrix
+                 hop_length=32,
+                 freeze_parameters=True):
+        super().__init__()
+
+        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
+        fft_window = librosa.util.pad_center(fft_window, size=n_time)
+
+        out_channels = n_fft // 2 + 1
+
+        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
+        omega = np.exp(-2 * np.pi * 1j / n_time)
+        dft_matrix = np.power(omega, x * y)  # (n_fft, n_time)
+        dft_matrix = dft_matrix * fft_window[None, :]
+        dft_matrix = dft_matrix[0:out_channels, :]
+        dft_matrix = dft_matrix[:, None, :]
+
+        # ---- asymmetric, non-square DFT
+        self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_real.weight.data = torch.tensor(np.real(dft_matrix), dtype=self.conv_real.weight.dtype).to(self.conv_real.weight.device)
+        self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix), dtype=self.conv_imag.weight.dtype).to(self.conv_imag.weight.device)
+        if freeze_parameters:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, input):
+        x = input[:, None, :]
+        real = self.conv_real(x)
+        imag = self.conv_imag(x)
+        return real ** 2 + imag ** 2  # bs, freq, time-frames
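
The frozen convolution pair computes the STFT directly, so ONNX export needs no FFT op. A sketch of the equivalence (an assumed check, not part of the commit): torch.stft with a periodic Hann window and center=False should yield the same power spectrogram.

    sig = torch.randn(1, 16000)
    spec = Spectrogram()(sig)  # (1, 33, 499)
    ref = torch.stft(sig, n_fft=64, hop_length=32,
                     window=torch.hann_window(64), center=False,
                     return_complex=True).abs() ** 2
    assert torch.allclose(spec, ref, rtol=1e-4, atol=1e-4)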
+
+
+class LogmelFilterBank(nn.Module):
+    def __init__(self,
+                 sr=16000,
+                 n_fft=64,
+                 n_mels=26,   # halved to 13 by the max-pool
+                 fmin=0.0,
+                 freeze_parameters=True):
+        super().__init__()
+
+        fmax = sr // 2
+
+        W2 = librosa.filters.mel(sr=sr,
+                                 n_fft=n_fft,
+                                 n_mels=n_mels,
+                                 fmin=fmin,
+                                 fmax=fmax).T
+
+        self.register_buffer('melW', torch.Tensor(W2))
+        self.register_buffer('amin', torch.Tensor([1e-10]))
+
+    def forward(self, x):
+        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)  # changes the mel axis, not the number of frames
+        x = torch.where(x > self.amin, x, self.amin)  # clamp, not in place
+        x = 10 * torch.log10(x)
+        return x
+
+
+def length_after_conv_layer(_length, k=None, pad=None, stride=None):
+    return torch.floor((_length + 2 * pad - k) / stride + 1)
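
Only the STFT hop and the max-pool shrink the time axis, so the frame count is easy to track with length_after_conv_layer. A worked example for a 4 s clip at 16 kHz (64000 samples):

    L = torch.tensor([64000.])
    L = length_after_conv_layer(L, k=64, pad=0, stride=32)  # STFT: (64000 - 64) / 32 + 1 = 1999 frames
    L = length_after_conv_layer(L, k=3, pad=1, stride=2)    # max-pool: (1999 + 2 - 3) / 2 + 1 = 1000 frames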
+
+
+class Conv(nn.Module):
+
+    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
+        super().__init__()
+        self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
+        self.norm = nn.BatchNorm2d(c_out)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return torch.relu_(x)
+
+
+class Vgg7(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.l1 = Conv( 1, 13)
+        self.l2 = Conv(13, 13)
+        self.l3 = Conv(13, 13)
+        self.maxpool_A = nn.MaxPool2d(3,
+                                      stride=2,
+                                      padding=1)
+        self.l4 = Conv(13, 13)
+        self.l5 = Conv(13, 13)
+        self.l6 = Conv(13, 13)
+        self.l7 = Conv(13, 13)
+        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
+        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # pools time; mel is reshaped into channels after pooling
+        self.spectrogram_extractor = Spectrogram()
+        self.logmel_extractor = LogmelFilterBank()
+
+    def final_length(self, L):
+        conv_kernel = [64, 3]  # [n_fft, maxpool kernel]
+        conv_stride = [32, 2]  # [hop_length, maxpool stride]; only layers with stride > 1 matter
+        conv_pad = [0, 1]      # [STFT pad, maxpool pad]
+        for k, stride, pad in zip(conv_kernel, conv_stride, conv_pad):
+            L = length_after_conv_layer(L, k=k, stride=stride, pad=pad)
+        return L
+
+    def final_attention_mask(self, feature_vector_length, attention_mask=None):
+        non_padded_lengths = attention_mask.sum(1)
+        out_lengths = self.final_length(non_padded_lengths)  # an exact 0 here means the attention mask was never filled
+        out_lengths = out_lengths.to(torch.long)
+        bs, _ = attention_mask.shape
+        attention_mask = torch.ones((bs, feature_vector_length),
+                                    dtype=attention_mask.dtype,
+                                    device=attention_mask.device)
+        for b, _len in enumerate(out_lengths):
+            attention_mask[b, _len:] = 0
+        return attention_mask
+
+    def forward(self, x, attention_mask=None):
+        x = _prenorm(x, attention_mask=attention_mask)
+        x = self.spectrogram_extractor(x)
+        x = self.logmel_extractor(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        x = self.maxpool_A(x)
+        x = self.l4(x)
+        x = self.l5(x)
+        x = self.l6(x)
+        x = self.l7(x)
+        if attention_mask is not None:
+            bs, _, t, _ = x.shape
+            a = self.final_attention_mask(feature_vector_length=t,
+                                          attention_mask=attention_mask)[:, None, :, None]
+            x = torch.masked_fill(x, a < 1, 0)
+            # the mask also has to gate lin: padded frames are pushed out of the softmax below
+            x = self.lin(x) * (self.sof(x) - 10000. * torch.logical_not(a)).softmax(2)
+        else:
+            x = self.lin(x) * self.sof(x).softmax(2)
+
+        x = x.sum(2)  # bs, ch, time-frames, HALF_MEL -> bs, ch, HALF_MEL
+        xT = x.transpose(1, 2)
+        x = torch.cat([x,
+                       torch.bmm(x, xT),  # correlation: (ch x mel) @ (mel x ch)
+                       ], 2)
+        return x.reshape(-1, 338)
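
A shape walk-through (a sketch; 4 s of mono audio assumed): the softmax-weighted time pooling collapses (bs, 13, 1000, 13) to (bs, 13, 13), and concatenating the 13x13 channel-mel correlation matrix yields the 338-dim feature.

    net = Vgg7().eval()
    with torch.no_grad():
        feats = net(torch.randn(1, 64000))
    print(feats.shape)  # torch.Size([1, 338])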
+
+
+class Wav2SmallConfig(PretrainedConfig):
+    model_type = "wav2vec2"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.half_mel = 13
+        self.n_fft = 64
+        self.n_time = 64
+        self.hidden = 2 * self.half_mel * self.half_mel
+        self.hop = self.n_time // 2
+
+
+class Wav2Small(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vgg7 = Vgg7()
+        self.adv = nn.Linear(config.hidden, 3)  # 0 = arousal, 1 = dominance, 2 = valence
+
+    def forward(self, x, attention_mask=None):
+        x = self.vgg7(x, attention_mask=attention_mask)
+        return self.adv(x)
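
An inference sketch for the student (assumes one of the example .wav files shipped with the Space and the checkpoint loaded below):

    wav, _ = librosa.load('male-60-angry.wav', sr=16000)
    x = torch.from_numpy(wav[None, :]).to(torch.float)
    with torch.no_grad():
        adv = Wav2Small.from_pretrained('audeering/wav2small').eval()(x)  # (1, 3): arousal, dominance, valence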
+
+
+def _ccc(x, y):
+    '''If len(x) = len(y) = 1 this is 0/0. Because a and b can both be negative, the 1e-7 added to the
+    denominator has to preserve its sign: add 1e-7 if sign >= 0, subtract 1e-7 if sign < 0.'''
+    mean_y = y.mean()
+    mean_x = x.mean()
+    a = x - mean_x
+    b = y - mean_y
+    L = (mean_x - mean_y).abs() * .1 * x.shape[0]
+    numerator = torch.dot(a, b)  # [the official CCC has the L term only in the denominator]
+    denominator = torch.dot(a, a) + torch.dot(b, b) + L  # for equal scalars a, b the dot products are all zero
+    denominator = torch.where(denominator.sign() < 0,
+                              denominator - 1e-7,
+                              denominator + 1e-7)
+    ccc = numerator / denominator
+
+    return -ccc
+
+
+wav2small = Wav2Small.from_pretrained('audeering/wav2small').to(device).eval()
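
A behaviour sketch for this loss (note the definition omits the usual factor of 2 in the numerator of CCC, so identical sequences score -0.5 rather than -1):

    t = torch.tensor([0.2, 0.5, 0.9])
    print(_ccc(t, t))  # ~ tensor(-0.5000); disagreement moves the value toward 0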
 
-fig_error, ax = plt.subplots(figsize=(8, 6))
-
-# Set the text to display
-error_message = "Error: No .wav or Mic. audio provided."
-
-# Add the text to the plot. We'll place it in the center of the plot
+
+# Error figure for the first plot
+fig_error, ax = plt.subplots(figsize=(8, 6))
+error_message = "Error: No .wav or Mic. audio provided."
 ax.text(0.5, 0.5, error_message,
         ha='center',
         va='center',
@@ -113,125 +357,164 @@ ax.text(0.5, 0.5, error_message,
         color='gray',
         fontweight='bold',
         transform=ax.transAxes)
-
-# Hide the axis ticks and labels for a cleaner look
 ax.set_xticks([])
 ax.set_yticks([])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
-
-# Optional: Add a border around the text to make it stand out more
 ax.set_frame_on(True)
 ax.spines['top'].set_visible(False)
 ax.spines['right'].set_visible(False)
 ax.spines['bottom'].set_visible(False)
 ax.spines['left'].set_visible(False)
 
-
-
-
 def process_audio(audio_filepath):
     if audio_filepath is None:
-        return fig_error
+        return fig_error, fig_error
 
-    # Load the audio file
-    waveform, sample_rate = librosa.load(audio_filepath)
+    waveform, sample_rate = librosa.load(audio_filepath, sr=None)
 
-    # Ensure audio is mono: if stereo, take the mean across channels
-
-    # Resample audio to 16kHz if necessary
+    # Resample audio to 16 kHz if the sample rate differs
     if sample_rate != 16000:
         resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
-        x = torch.from_numpy(resampled_waveform_np)
-        x = x[:, :64000]  # 4 s
+    else:
+        resampled_waveform_np = waveform[None, :]
+
+    x = torch.from_numpy(resampled_waveform_np).to(torch.float)
+
     with torch.no_grad():
         logits_dawn = dawn(x).cpu().numpy()[0, :]
         logits_wavlm = base(x).cpu().numpy()[0, :]
-
-        logits_wav2small = .5 * logits_dawn + .5 * logits_wavlm
+
+        # 17K params
+        logits_wav2small = wav2small(x).cpu().numpy()[0, :]
+
+    # --- Plot 1: Wav2Vec2 (Dawn) vs Wav2Small student outputs ---
+    fig, ax = plt.subplots(figsize=(10, 6))
 
-    # left_bars_data = np.array([0.75, 0.5, 0.9])
-    # right_bars_data = np.array([0.3, 0.8, 0.65])
     left_bars_data = logits_dawn.clip(0, 1)
     right_bars_data = logits_wav2small.clip(0, 1)
 
     bar_labels = ['\nArousal', '\nDominance', '\nValence']
     y_pos = np.arange(len(bar_labels))
 
-    # Define the base colormaps for each category to ensure a different color per row
-    # Using Greys for Dominance as requested
+    # Define colormaps for each category to ensure distinct colors
     category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
 
-    # Define color shades for left and right for each category
     left_filled_colors = []
     right_filled_colors = []
     background_colors = []
 
+    # Assign specific shades for filled bars and background bars
     for i, cmap in enumerate(category_colormaps):
-        # Pick a darker shade for the left filled bar
-        left_filled_colors.append(cmap(0.74))  # 0.7
-        # Pick a slightly lighter shade for the right filled bar
-        right_filled_colors.append(cmap(0.64))  # 0.5
-        # Pick a very light shade for the transparent background bar
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
        background_colors.append(cmap(0.1))
 
-    # Set up the figure and axes
-    fig, ax = plt.subplots(figsize=(10, 6))
-
-    # Plot the background bars with transparency
+    # Plot transparent background bars
     for i in range(len(bar_labels)):
-        # Left background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
-        # Right background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
 
-    # Plot the filled bars for the left and right side
+    # Plot the filled bars for the actual data
     for i in range(len(bar_labels)):
-        # Left filled bar (opaque, darker shade of category color)
         ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
-        # Right filled bar (opaque, lighter shade of category color)
         ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
 
-    # Add a central axis divider
+    # Add a central vertical axis divider
     ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
 
-    # Set x-axis limits and y-axis ticks
+    # Set x-axis limits and y-axis ticks/labels
     ax.set_xlim(-1, 1)
     ax.set_yticks(y_pos)
     ax.set_yticklabels(bar_labels, fontsize=12)
 
+    # Custom formatter for the x-axis: show absolute percentage values
     def abs_tick_formatter(x, pos):
         return f'{int(abs(x) * 100)}%'
     ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
 
-    # Add a clean title and labels
     ax.set_title('', fontsize=16, pad=20)
-    ax.set_xlabel('Outputs of Wav2Vec2 / Outputs of Wav2Small Teacher', fontsize=12)
+    ax.set_xlabel('Wav2Vec2 (Dawn)          Wav2Small (17K param.)', fontsize=12)
 
-    # Remove the top and right spines for a cleaner look
+    # Remove the top, right, and left spines for a cleaner look
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
     ax.spines['left'].set_visible(False)
 
-    # Add annotations to the filled bars for clarity
+    # Add annotations (percentage values) to the filled bars
     for i in range(len(bar_labels)):
-        # Left annotation (uses left_filled_colors for text color)
         ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                 va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
-        # Right annotation (uses right_filled_colors for text color)
         ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                 va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
 
-    return fig
+    # -- Plot 2: WavLM baseline vs Wav2Small teacher --
+    fig_2, ax_2 = plt.subplots(figsize=(10, 6))
+
+    left_bars_data = logits_wavlm.clip(0, 1)
+    right_bars_data = (.5 * logits_dawn + .5 * logits_wavlm).clip(0, 1)
+
+    bar_labels = ['\nArousal', '\nDominance', '\nValence']
+    y_pos = np.arange(len(bar_labels))
+
+    # Define colormaps for each category to ensure distinct colors
+    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
+
+    left_filled_colors = []
+    right_filled_colors = []
+    background_colors = []
+
+    # Assign specific shades for filled bars and background bars
+    for i, cmap in enumerate(category_colormaps):
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
+        background_colors.append(cmap(0.1))
+
+    # Plot transparent background bars
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
+        ax_2.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
+
+    # Plot the filled bars for the actual data
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
+        ax_2.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
+
+    # Add a central vertical axis divider
+    ax_2.axvline(0, color='black', linewidth=0.8, linestyle='--')
+
+    # Set x-axis limits and y-axis ticks/labels
+    ax_2.set_xlim(-1, 1)
+    ax_2.set_yticks(y_pos)
+    ax_2.set_yticklabels(bar_labels, fontsize=12)
+
+    # Custom formatter for the x-axis: show absolute percentage values
+    def abs_tick_formatter(x, pos):
+        return f'{int(abs(x) * 100)}%'
+    ax_2.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
+    ax_2.set_title('', fontsize=16, pad=20)
+    ax_2.set_xlabel('WavLM (Baseline)          Wav2Small Teacher (0.4B param.)', fontsize=12)
+    ax_2.spines['top'].set_visible(False)
+    ax_2.spines['right'].set_visible(False)
+    ax_2.spines['left'].set_visible(False)
+
+    # Add annotations (percentage values) to the filled bars
+    for i in range(len(bar_labels)):
+        ax_2.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
+                  va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
+        ax_2.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
+                  va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
+    return fig, fig_2
 
 
 iface = gr.Interface(
@@ -242,25 +525,27 @@ iface = gr.Interface(
         label=''
     ),
     outputs=[
-        gr.Plot(label="Arousal / Dominance / Valence Plots"),
+        gr.Plot(label="Wav2Vec2 vs Wav2Small (17K params) Plot"),  # first plot output
+        gr.Plot(label="WavLM vs Wav2Small Teacher Plot"),          # second plot output
     ],
     title='',
     description='',
-    flagging_mode="never",  # save audio and .csv on the machine?
+    flagging_mode="never",  # disables the flagging feature
     examples=[
         "female-46-neutral.wav",
         "female-20-happy.wav",
         "male-60-angry.wav",
         "male-27-sad.wav",
     ],
-    css="footer {visibility: hidden}"
+    css="footer {visibility: hidden}"  # hides the Gradio footer
 )
 
+# Gradio Blocks for the tabbed interface
 with gr.Blocks() as demo:
-
-    # https://discuss.huggingface.co/t/how-to-get-the-microphone-streaming-input-file-when-using-blocks/37204/3
+    # First tab: the Arousal / Dominance / Valence plots
     with gr.Tab(label="Arousal / Dominance / Valence"):
         iface.render()
+    # Second tab: CCC (Concordance Correlation Coefficient) scores
     with gr.Tab(label="CCC"):
         gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
 <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
@@ -269,5 +554,6 @@ with gr.Blocks() as demo:
 </table>
 ''')
 
+# Launch the Gradio application
 if __name__ == "__main__":
     demo.launch(share=False)
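
Since process_audio now returns two figures, the outputs list wires them positionally onto the two gr.Plot components. A local smoke test (a sketch; the example file is assumed to ship with the Space):

    fig1, fig2 = process_audio("male-27-sad.wav")
    fig1.savefig("dawn_vs_student.png")
    fig2.savefig("wavlm_vs_teacher.png")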
 