# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

#################### Norm2D for Discriminators ####################

import torch
import torch.nn as nn
import einops
from torch.nn.utils import spectral_norm, weight_norm

CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)

class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to the last dimension
    before running the normalization, and moves them back to their original
    position right after.
    """

    def __init__(self, normalized_shape, **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, "b ... t -> b t ...")
        x = super().forward(x)
        x = einops.rearrange(x, "b t ... -> b ... t")
        return x
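
# Usage sketch (added example, not in the upstream file): for a
# (batch, channels, time) input, ConvLayerNorm normalizes over the channel
# dimension while keeping the channels-first layout, e.g.
#
#     ln = ConvLayerNorm(8)
#     y = ln(torch.randn(2, 8, 50))  # y.shape == (2, 8, 50)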

def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    elif norm == "spectral_norm":
        return spectral_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any
        # other choice doesn't need reparametrization.
        return module
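
# Sketch (added example): the wrapped module keeps its interface; only the
# weight parametrization changes, e.g.
#
#     conv = apply_parametrization_norm(nn.Conv2d(1, 16, 3), norm="weight_norm")
#     # `conv` is still an nn.Conv2d; with the legacy torch.nn.utils.weight_norm
#     # imported above, its weight is split into magnitude (weight_g) and
#     # direction (weight_v).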

def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure
    the returned module is causal, or raise an error if the normalization doesn't
    support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()
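
# Sketch (added example): the returned module is sized from the conv's
# out_channels; reparametrizing norms are handled elsewhere and fall through
# to nn.Identity(), e.g.
#
#     conv = nn.Conv2d(1, 16, kernel_size=3)
#     norm = get_norm_module(conv, norm="time_group_norm")  # GroupNorm(1, 16)
#     norm = get_norm_module(conv, norm="weight_norm")      # nn.Identity()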

class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: dict = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
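
if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the upstream file): run a
    # random spectrogram-like input through each 2D-compatible normalization.
    # "layer_norm" is omitted because ConvLayerNorm normalizes the channel
    # dimension of (batch, channels, time) inputs and targets 1D convs.
    x = torch.randn(1, 1, 64, 64)
    for norm in ("none", "weight_norm", "spectral_norm", "time_group_norm"):
        conv = NormConv2d(1, 16, kernel_size=3, padding=1, norm=norm)
        print(norm, tuple(conv(x).shape))  # (1, 16, 64, 64) for every choice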