ORI-Muchim committed
Commit
84001e4
1 Parent(s): d4d0d61

Update attentions.py

Files changed (1):
  attentions.py +159 -5
attentions.py CHANGED
@@ -1,14 +1,19 @@
+import copy
 import math
+import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
 
 import commons
+import modules
 from modules import LayerNorm
-
 
-class Encoder(nn.Module):
-  def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
+class Encoder(nn.Module): #backward compatible vits2 encoder
+  def __init__(
+    self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs
+  ):
     super().__init__()
     self.hidden_channels = hidden_channels
     self.filter_channels = filter_channels
@@ -23,16 +28,32 @@ class Encoder(nn.Module):
     self.norm_layers_1 = nn.ModuleList()
     self.ffn_layers = nn.ModuleList()
     self.norm_layers_2 = nn.ModuleList()
+    # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels
+    self.cond_layer_idx = self.n_layers
+    if 'gin_channels' in kwargs:
+      self.gin_channels = kwargs['gin_channels']
+      if self.gin_channels != 0:
+        self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+        # vits2 says 3rd block, so idx is 2 by default
+        self.cond_layer_idx = kwargs['cond_layer_idx'] if 'cond_layer_idx' in kwargs else 2
+        print(self.gin_channels, self.cond_layer_idx)
+        assert self.cond_layer_idx < self.n_layers, 'cond_layer_idx should be less than n_layers'
+
     for i in range(self.n_layers):
       self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
      self.norm_layers_1.append(LayerNorm(hidden_channels))
       self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
       self.norm_layers_2.append(LayerNorm(hidden_channels))
 
-  def forward(self, x, x_mask):
+  def forward(self, x, x_mask, g=None):
     attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
     x = x * x_mask
     for i in range(self.n_layers):
+      if i == self.cond_layer_idx and g is not None:
+        g = self.spk_emb_linear(g.transpose(1, 2))
+        g = g.transpose(1, 2)
+        x = x + g
+        x = x * x_mask
       y = self.attn_layers[i](x, x, attn_mask)
       y = self.drop(y)
       x = self.norm_layers_1[i](x + y)
@@ -43,7 +64,6 @@ class Encoder(nn.Module):
     x = x * x_mask
     return x
 
-
 class Decoder(nn.Module):
   def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
     super().__init__()
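The encoder change above is backward compatible: without a speaker embedding the forward pass is unchanged, while a global conditioning vector g of size gin_channels is projected to hidden_channels and added to the hidden states right before block cond_layer_idx (2 by default, i.e. the third block, as in VITS2). A minimal usage sketch, not part of the commit; the channel sizes, tensor shapes and import path are illustrative assumptions, and running it requires this repository's commons and modules to be importable:

import torch
from attentions import Encoder   # assumed import path for this file

B, T = 2, 50                     # illustrative batch size and frame count
hidden, gin = 192, 256           # illustrative hidden and speaker-embedding sizes

enc = Encoder(hidden, filter_channels=768, n_heads=2, n_layers=6,
              kernel_size=3, p_dropout=0.1,
              gin_channels=gin)  # cond_layer_idx left at its default of 2
x = torch.randn(B, hidden, T)    # text-encoder hidden states
x_mask = torch.ones(B, 1, T)     # 1 = valid position
g = torch.randn(B, gin, 1)       # per-utterance speaker embedding, broadcast over time
y = enc(x, x_mask, g=g)          # g is injected before the third attention block
y_base = enc(x, x_mask)          # without g, behaves like the original VITS encoder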
 
@@ -298,3 +318,137 @@ class FFN(nn.Module):
     padding = [[0, 0], [0, 0], [pad_l, pad_r]]
     x = F.pad(x, commons.convert_pad_shape(padding))
     return x
+
+
+class Depthwise_Separable_Conv1D(nn.Module):
+  def __init__(
+    self,
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride=1,
+    padding=0,
+    dilation=1,
+    bias=True,
+    padding_mode='zeros',  # TODO: refine this type
+    device=None,
+    dtype=None
+  ):
+    super().__init__()
+    self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
+    self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype)
+
+  def forward(self, input):
+    return self.point_conv(self.depth_conv(input))
+
+  def weight_norm(self):
+    self.depth_conv = weight_norm(self.depth_conv, name='weight')
+    self.point_conv = weight_norm(self.point_conv, name='weight')
+
+  def remove_weight_norm(self):
+    self.depth_conv = remove_weight_norm(self.depth_conv, name='weight')
+    self.point_conv = remove_weight_norm(self.point_conv, name='weight')
+
+class Depthwise_Separable_TransposeConv1D(nn.Module):
+  def __init__(
+    self,
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride=1,
+    padding=0,
+    output_padding=0,
+    bias=True,
+    dilation=1,
+    padding_mode='zeros',  # TODO: refine this type
+    device=None,
+    dtype=None
+  ):
+    super().__init__()
+    self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, output_padding=output_padding, padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
+    self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype)
+
+  def forward(self, input):
+    return self.point_conv(self.depth_conv(input))
+
+  def weight_norm(self):
+    self.depth_conv = weight_norm(self.depth_conv, name='weight')
+    self.point_conv = weight_norm(self.point_conv, name='weight')
+
+  def remove_weight_norm(self):
+    remove_weight_norm(self.depth_conv, name='weight')
+    remove_weight_norm(self.point_conv, name='weight')
+
+
+def weight_norm_modules(module, name='weight', dim=0):
+  if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
+    module.weight_norm()
+    return module
+  else:
+    return weight_norm(module, name, dim)
+
+def remove_weight_norm_modules(module, name='weight'):
+  if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
+    module.remove_weight_norm()
+  else:
+    remove_weight_norm(module, name)
+
+class FFT(nn.Module):
+  def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
+               proximal_bias=False, proximal_init=True, isflow=False, **kwargs):
+    super().__init__()
+    self.hidden_channels = hidden_channels
+    self.filter_channels = filter_channels
+    self.n_heads = n_heads
+    self.n_layers = n_layers
+    self.kernel_size = kernel_size
+    self.p_dropout = p_dropout
+    self.proximal_bias = proximal_bias
+    self.proximal_init = proximal_init
+    if isflow and 'gin_channels' in kwargs and kwargs["gin_channels"] > 0:
+      cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
+      self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
+      self.cond_layer = weight_norm_modules(cond_layer, name='weight')
+      self.gin_channels = kwargs["gin_channels"]
+    self.drop = nn.Dropout(p_dropout)
+    self.self_attn_layers = nn.ModuleList()
+    self.norm_layers_0 = nn.ModuleList()
+    self.ffn_layers = nn.ModuleList()
+    self.norm_layers_1 = nn.ModuleList()
+    for i in range(self.n_layers):
+      self.self_attn_layers.append(
+        MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias,
+                           proximal_init=proximal_init))
+      self.norm_layers_0.append(LayerNorm(hidden_channels))
+      self.ffn_layers.append(
+        FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
+      self.norm_layers_1.append(LayerNorm(hidden_channels))
+
+  def forward(self, x, x_mask, g=None):
+    """
+    x: decoder input
+    h: encoder output
+    """
+    if g is not None:
+      g = self.cond_layer(g)
+
+    self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
+    x = x * x_mask
+    for i in range(self.n_layers):
+      if g is not None:
+        x = self.cond_pre(x)
+        cond_offset = i * 2 * self.hidden_channels
+        g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+        x = commons.fused_add_tanh_sigmoid_multiply(
+          x,
+          g_l,
+          torch.IntTensor([self.hidden_channels]))
+      y = self.self_attn_layers[i](x, x, self_attn_mask)
+      y = self.drop(y)
+      x = self.norm_layers_0[i](x + y)
+
+      y = self.ffn_layers[i](x, x_mask)
+      y = self.drop(y)
+      x = self.norm_layers_1[i](x + y)
+    x = x * x_mask
+    return x
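
The two Depthwise_Separable_* blocks factor a dense 1-D (transposed) convolution into a per-channel grouped convolution followed by a 1x1 pointwise convolution, which preserves the receptive field while using far fewer weights; weight_norm_modules and remove_weight_norm_modules dispatch either to the blocks' own weight_norm()/remove_weight_norm() methods or to torch.nn.utils.weight_norm for ordinary modules. A short sketch of how they could be exercised; the channel sizes and import path are illustrative assumptions, not taken from the repository:

import torch
from torch import nn
from attentions import Depthwise_Separable_Conv1D, weight_norm_modules, remove_weight_norm_modules

sep = Depthwise_Separable_Conv1D(in_channels=192, out_channels=384, kernel_size=5, padding=2)
x = torch.randn(4, 192, 100)
print(sep(x).shape)              # torch.Size([4, 384, 100]), same shape as nn.Conv1d(192, 384, 5, padding=2)

# Weight count (ignoring biases): 192*5 + 192*384 for the separable block
# versus 192*384*5 for the equivalent dense convolution.

sep = weight_norm_modules(sep)   # routes to sep.weight_norm()
dense = weight_norm_modules(nn.Conv1d(192, 192, 3, padding=1))  # falls back to torch's weight_norm
remove_weight_norm_modules(sep)
remove_weight_norm_modules(dense)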
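FFT is a causal transformer stack: commons.subsequent_mask keeps each position from attending to future positions, and when constructed with isflow=True and a positive gin_channels it folds a global conditioning signal into every layer via cond_pre, cond_layer and commons.fused_add_tanh_sigmoid_multiply (WaveNet-style gating). A hedged sketch of driving the block; the layer sizes and conditioning shape are assumptions, and it relies on this repository's commons module being importable:

import torch
from attentions import FFT

fft = FFT(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=3,
          kernel_size=3, p_dropout=0.1, isflow=True, gin_channels=256)
x = torch.randn(2, 192, 50)      # e.g. flow features to be transformed
x_mask = torch.ones(2, 1, 50)    # forward builds its causal mask from x_mask.size(2)
g = torch.randn(2, 256, 1)       # global conditioning, broadcast over time by the gating
out = fft(x, x_mask, g=g)        # -> (2, 192, 50)
out_uncond = fft(x, x_mask)      # the conditioning branch is skipped when g is None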