import torch
import torch.nn as nn
import torch.nn.functional as F


def init_weights_func(m):
    # Xavier-initialize all Conv1d layers; other module types keep their defaults.
    classname = m.__class__.__name__
    if classname.find("Conv1d") != -1:
        torch.nn.init.xavier_uniform_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer-normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        # For channel-first inputs (e.g. dim=1 on [B, C, T]), swap the target
        # axis to the last position, normalize, and swap back.
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


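# A minimal sketch (illustrative, not part of the original module) of what dim=1
# buys us: on a channel-first [B, C, T] tensor, LayerNorm(nout, dim=1) matches
# the stock torch.nn.LayerNorm applied channel-last.
def _layer_norm_demo():
    x = torch.randn(2, 8, 50)                      # [B, C=8, T]
    y = LayerNorm(8, dim=1)(x)                     # normalizes over the C axis
    ref = torch.nn.LayerNorm(8)(x.transpose(1, -1)).transpose(1, -1)
    assert torch.allclose(y, ref, atol=1e-6)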
					
						
class ResidualBlock(nn.Module):
    """Implements (norm -> conv -> GELU -> conv) n times, with residual connections."""

    def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
                 c_multiple=2, ln_eps=1e-12, bias=False):
        super(ResidualBlock, self).__init__()

        if norm_type == 'bn':
            norm_builder = lambda: nn.BatchNorm1d(channels)
        elif norm_type == 'in':
            norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
        elif norm_type == 'gn':
            norm_builder = lambda: nn.GroupNorm(8, channels)
        elif norm_type == 'ln':
            norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
        else:
            norm_builder = lambda: nn.Identity()

        self.blocks = nn.ModuleList([
            nn.Sequential(
                norm_builder(),
                nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
                          padding=(dilation * (kernel_size - 1)) // 2, bias=bias),
                LambdaLayer(lambda x: x * kernel_size ** -0.5),
                nn.GELU(),
                nn.Conv1d(c_multiple * channels, channels, 1, bias=bias),
            )
            for _ in range(n)
        ])
        self.dropout = dropout

    def forward(self, x):
        # Time steps whose channels are all zero are treated as padding and
        # re-masked after every residual update.
        nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        for b in self.blocks:
            x_ = b(x)
            if self.dropout > 0 and self.training:
                x_ = F.dropout(x_, self.dropout, training=self.training)
            x = x + x_
            x = x * nonpadding
        return x


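# A quick shape check (illustrative values, not from the original module): the
# block is channel-first and shape-preserving, and zero-padded steps stay zero.
def _residual_block_demo():
    block = ResidualBlock(channels=16, kernel_size=3, dilation=1, norm_type='gn')
    x = torch.randn(2, 16, 40)                     # [B, C, T]
    x[:, :, 30:] = 0.                              # simulate right-padding
    y = block(x)
    assert y.shape == (2, 16, 40)
    assert y[:, :, 30:].abs().sum() == 0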
					
						
class ConvBlocks(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms."""

    def __init__(self, channels, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False):
        super(ConvBlocks, self).__init__()
        self.is_BTC = is_BTC
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(channels, kernel_size, d,
                            n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
                            dropout=dropout, ln_eps=ln_eps, bias=bias)
              for d in dilations],
        )
        if norm_type == 'bn':
            norm = nn.BatchNorm1d(channels)
        elif norm_type == 'in':
            norm = nn.InstanceNorm1d(channels, affine=True)
        elif norm_type == 'gn':
            norm = nn.GroupNorm(8, channels)
        elif norm_type == 'ln':
            norm = LayerNorm(channels, dim=1, eps=ln_eps)
        else:
            # Fall back to identity for unknown norm types, mirroring ResidualBlock.
            norm = nn.Identity()
        self.last_norm = norm
        self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias)
        if init_weights:
            self.apply(init_weights_func)

    def forward(self, x):
        """
        :param x: [B, T, H]
        :return:  [B, T, H]
        """
        if self.is_BTC:
            x = x.transpose(1, 2)                  # [B, T, H] -> [B, H, T]
        nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        x = self.res_blocks(x) * nonpadding
        x = self.last_norm(x) * nonpadding
        x = self.post_net1(x) * nonpadding
        if self.is_BTC:
            x = x.transpose(1, 2)                  # back to [B, T, H]
        return x


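# A hedged smoke test (dims are illustrative): with is_BTC=True, ConvBlocks maps
# [B, T, H=channels] to [B, T, out_dims].
def _conv_blocks_demo():
    net = ConvBlocks(channels=32, out_dims=16, dilations=[1, 2, 4], kernel_size=3)
    x = torch.randn(2, 100, 32)                    # [B, T, H]
    assert net(x).shape == (2, 100, 16)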
					
						
class SeqLevelConvolutionalModel(nn.Module):
    def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'):
        nn.Module.__init__(self)
        self.audio_feat_type = audio_feat_type
        if audio_feat_type == 'ppg':
            self.audio_encoder = nn.Sequential(*[
                nn.Conv1d(29, 48, 3, 1, 1, bias=False),
                nn.BatchNorm1d(48) if norm_type == 'bn' else LayerNorm(48, dim=1),
                nn.GELU(),
                nn.Conv1d(48, 48, 3, 1, 1, bias=False)
            ])
            self.energy_encoder = nn.Sequential(*[
                nn.Conv1d(1, 16, 3, 1, 1, bias=False),
                nn.BatchNorm1d(16) if norm_type == 'bn' else LayerNorm(16, dim=1),
                nn.GELU(),
                nn.Conv1d(16, 16, 3, 1, 1, bias=False)
            ])
        elif audio_feat_type == 'mel':
            self.mel_encoder = nn.Sequential(*[
                nn.Conv1d(80, 64, 3, 1, 1, bias=False),
                nn.BatchNorm1d(64) if norm_type == 'bn' else LayerNorm(64, dim=1),
                nn.GELU(),
                nn.Conv1d(64, 64, 3, 1, 1, bias=False)
            ])
        else:
            raise NotImplementedError("Only 'ppg' and 'mel' audio features are supported!")

        self.style_encoder = nn.Sequential(*[
            nn.Linear(135, 256),
            nn.GELU(),
            nn.Linear(256, 256)
        ])

        if backbone_type == 'resnet':
            self.backbone = ResNetBackbone()
        elif backbone_type == 'unet':
            self.backbone = UNetBackbone()
        elif backbone_type == 'resblocks':
            self.backbone = ResBlocksBackbone()
        else:
            raise NotImplementedError("Only 'resnet', 'unet' and 'resblocks' backbones are supported!")

        self.out_layer = nn.Sequential(
            nn.BatchNorm1d(512) if norm_type == 'bn' else LayerNorm(512, dim=1),
            nn.Conv1d(512, 64, 3, 1, 1, bias=False),
            nn.PReLU(),
            nn.Conv1d(64, out_dim, 3, 1, 1, bias=False)
        )
        self.feat_dropout = nn.Dropout(p=dropout)  # defined but not applied in forward()

    @property
    def device(self):
        return next(self.backbone.parameters()).device

    def forward(self, batch, ret, log_dict=None):
        style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device)
        style_feat = self.style_encoder(style)

        if self.audio_feat_type == 'ppg':
            audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device)
            audio_feat = self.audio_encoder(audio.transpose(1, 2)).transpose(1, 2) * x_mask.unsqueeze(2)
            energy_feat = self.energy_encoder(energy.transpose(1, 2)).transpose(1, 2) * x_mask.unsqueeze(2)
            feat = torch.cat([audio_feat, energy_feat], dim=2)   # [B, T, 48 + 16 = 64]
        elif self.audio_feat_type == 'mel':
            mel = batch['mel'].to(self.device)
            feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2) * x_mask.unsqueeze(2)

        feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask)

        out = self.out_layer(feat.transpose(1, 2)).transpose(1, 2) * x_mask.unsqueeze(2)

        ret['pred'] = out
        ret['mask'] = x_mask
        return out


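# A hedged end-to-end sketch (batch keys and shapes inferred from forward();
# the 29-dim 'audio' PPG, 1-dim 'energy', and 135-dim 'style' inputs follow the
# encoder definitions above; values are random placeholders).
def _seq_level_model_demo():
    model = SeqLevelConvolutionalModel(out_dim=64, audio_feat_type='ppg', backbone_type='unet')
    B, T = 2, 64                                   # T divisible by 8 for the U-Net
    batch = {
        'audio': torch.randn(B, T, 29),
        'energy': torch.randn(B, T, 1),
        'style': torch.randn(B, 135),
        'x_mask': torch.ones(B, T),
    }
    ret = {}
    out = model.eval()(batch, ret)                 # eval() freezes BatchNorm stats
    assert out.shape == (B, T // 2, 64)            # every backbone halves T
    assert ret['mask'].shape == (B, T // 2)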
					
						
class ResBlocksBackbone(nn.Module):
    def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
        super(ResBlocksBackbone, self).__init__()
        self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1] * 4, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1] * 14, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)

        self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
        self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))  # unused in this backbone

        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x, sty, x_mask=1.):
        """
        x: [B, T, C]
        sty: [B, C=256]
        x_mask: [B, T]
        returns: x [B, T/2, C], x_mask [B, T/2]
        """
        x = x.transpose(1, 2)                      # [B, C, T]
        x_mask = x_mask[:, None, :]

        x = self.resblocks_0(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T -> T/2
        x = self.downsampler(x) * x_mask
        x = self.resblocks_1(x) * x_mask
        x = self.resblocks_2(x) * x_mask

        x = self.dropout(x.transpose(1, 2)).transpose(1, 2)
        sty = sty[:, :, None].repeat([1, 1, x_mask.shape[2]])   # broadcast style along time
        x = torch.cat([x, sty], dim=1)             # 256 + 256 -> 512 channels
        x = self.resblocks_3(x) * x_mask
        x = self.resblocks_4(x) * x_mask

        x = x.transpose(1, 2)
        x_mask = x_mask.squeeze(1)
        return x, x_mask
					
						
class ResNetBackbone(nn.Module):
    def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
        super(ResNetBackbone, self).__init__()
        self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1] * 4, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1] * 14, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)

        self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
        self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))

        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x, sty, x_mask=1.):
        """
        x: [B, T, C]
        sty: [B, C=256]
        x_mask: [B, T]
        returns: x [B, T/2, C], x_mask [B, T/2]
        """
        x = x.transpose(1, 2)                      # [B, C, T]
        x_mask = x_mask[:, None, :]

        x = self.resblocks_0(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/2
        x = self.downsampler(x) * x_mask
        x = self.resblocks_1(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/4
        x = self.downsampler(x) * x_mask
        x = self.resblocks_2(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/8
        x = self.downsampler(x) * x_mask
        x = self.dropout(x.transpose(1, 2)).transpose(1, 2)
        sty = sty[:, :, None].repeat([1, 1, x_mask.shape[2]])   # broadcast style along time
        x = torch.cat([x, sty], dim=1)             # 256 + 256 -> 512 channels
        x = self.resblocks_3(x) * x_mask

        x_mask = self.upsampler(x_mask)            # back to T/2
        x = self.upsampler(x) * x_mask
        x = self.resblocks_4(x) * x_mask

        x = x.transpose(1, 2)
        x_mask = x_mask.squeeze(1)
        return x, x_mask
					
						
class UNetBackbone(nn.Module):
    def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
        super(UNetBackbone, self).__init__()
        self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1] * 4, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1] * 8, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)
        self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)   # 512 + 256-ch skip
        self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1] * 3, kernel_size=3, norm_type=norm_type, is_BTC=False)  # 512 + 128-ch skip

        self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear')
        self.upsampler = nn.Upsample(scale_factor=2, mode='linear')
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x, sty, x_mask=1.):
        """
        x: [B, T, C]
        sty: [B, C=256]
        x_mask: [B, T]
        returns: x [B, T/2, C], x_mask [B, T/2]
        """
        x = x.transpose(1, 2)                      # [B, C, T]
        x_mask = x_mask[:, None, :]

        x0 = self.resblocks_0(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/2
        x = self.downsampler(x0) * x_mask
        x1 = self.resblocks_1(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/4
        x = self.downsampler(x1) * x_mask
        x2 = self.resblocks_2(x) * x_mask

        x_mask = self.downsampler(x_mask)          # T/8 (bottleneck)
        x = self.downsampler(x2) * x_mask
        x = self.dropout(x.transpose(1, 2)).transpose(1, 2)
        sty = sty[:, :, None].repeat([1, 1, x_mask.shape[2]])   # broadcast style along time
        x = torch.cat([x, sty], dim=1)             # 256 + 256 -> 512 channels
        x3 = self.resblocks_3(x) * x_mask

        x_mask = self.upsampler(x_mask)            # T/4
        x = self.upsampler(x3) * x_mask
        x = torch.cat([x, self.dropout(x2.transpose(1, 2)).transpose(1, 2)], dim=1)   # skip from x2
        x4 = self.resblocks_4(x) * x_mask

        x_mask = self.upsampler(x_mask)            # T/2
        x = self.upsampler(x4) * x_mask
        x = torch.cat([x, self.dropout(x1.transpose(1, 2)).transpose(1, 2)], dim=1)   # skip from x1
        x5 = self.resblocks_5(x) * x_mask

        x = x5.transpose(1, 2)
        x_mask = x_mask.squeeze(1)
        return x, x_mask


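# A hedged consistency check (shapes illustrative): all three backbones map
# x [B, T, 64] with style [B, 256] to (x [B, T/2, 512], x_mask [B, T/2]).
def _backbone_demo():
    x, sty, mask = torch.randn(2, 64, 64), torch.randn(2, 256), torch.ones(2, 64)
    for backbone in (ResBlocksBackbone(), ResNetBackbone(), UNetBackbone()):
        y, y_mask = backbone.eval()(x, sty, mask)
        assert y.shape == (2, 32, 512) and y_mask.shape == (2, 32)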
					
						
if __name__ == '__main__':
    pass
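    # Run the hedged smoke tests defined above (all shapes are assumptions made
    # for illustration, not values from the original training setup).
    _layer_norm_demo()
    _residual_block_demo()
    _conv_blocks_demo()
    _backbone_demo()
    _seq_level_model_demo()
    print("smoke tests passed")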