# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py

import torch
from torch import nn
from torch.nn import functional as F
import math

from modules.flow.modules import *


class StochasticDurationPredictor(nn.Module):
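    """Stochastic duration predictor from VITS.

    A conditional normalizing flow over phoneme durations: when `reverse` is
    False, the forward pass returns the negative variational lower bound of
    the observed durations `w` given the text hidden states `x`; when
    `reverse` is True, it samples log-durations from noise.
    """
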
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
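        # NOTE: the `filter_channels` argument is deliberately overridden with
        # `in_channels`, matching upstream VITS.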
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels
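
        # Main flow stack over 2 channels (log-duration path and auxiliary
        # path): an elementwise affine layer followed by `n_flows`
        # ConvFlow/Flip pairs.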
        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())
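
        # Posterior network: encodes the ground-truth durations `w` and drives
        # its own flow stack, providing the variational dequantization of the
        # integer-valued durations during training.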
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(Flip())
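
        # Encoder for the (detached) text hidden states that condition both
        # the main and the posterior flows.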
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
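        # Optional projection of the global conditioning (e.g. a speaker
        # embedding) into the filter dimension.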
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
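        """Compute the duration loss (training) or sample log-durations (inference).

        When `reverse` is False, `w` (ground-truth durations, required) is used
        to compute the per-item negative variational lower bound `nll + logq`.
        When `reverse` is True, noise scaled by `noise_scale` is pushed through
        the inverted flows and the log-durations `logw` are returned.
        """
        # Detach inputs so the duration loss does not backpropagate into the
        # text encoder (or the global conditioning).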
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None
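
            # Encode the observed durations `w` as extra conditioning for the
            # posterior flows.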
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
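            # Sample e_q ~ N(0, I) and transform it with the posterior flows,
            # accumulating the log-determinants.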
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
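            # Split into u (squashed into (0, 1) to dequantize the integer
            # durations) and the auxiliary variable z1; the sigmoid's
            # log-determinant is added to the posterior log-density.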
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )
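
            # Push (log(w - u), z1) through the main flows and evaluate the
            # negative log-likelihood under the standard-normal prior.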
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq
        else:
            flows = list(reversed(self.flows))
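            # Drop one flow that is unused at inference (upstream VITS:
            # "remove a useless vflow").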
            flows = flows[:-2] + [flows[-1]]
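
            # Sample noise, run it through the inverted flows, and read the
            # log-durations off the first channel.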
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
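

# A minimal usage sketch (hypothetical shapes and channel sizes; assumes
# `modules.flow.modules` provides Log, Flip, ElementwiseAffine, ConvFlow,
# and DDSConv):
#
#     sdp = StochasticDurationPredictor(192, 192, 3, 0.5)
#     x = torch.randn(2, 192, 50)                    # text hidden states [B, C, T]
#     x_mask = torch.ones(2, 1, 50)                  # frame mask [B, 1, T]
#     w = torch.randint(1, 10, (2, 1, 50)).float()   # ground-truth durations [B, 1, T]
#     loss = sdp(x, x_mask, w=w).sum()               # training: nll + logq, shape [B]
#     logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # inference sampling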