import torch
from torch import nn, sin, pow
from torch.nn import Parameter
from torch.distributions.exponential import Exponential


class Snake(nn.Module):
    r'''
    Implementation of the serpentine-like sine-based periodic activation function:

    .. math::
         Snake_a := x + \frac{1}{a} \sin^2(ax) = x - \frac{1}{2a}\cos(2ax) + \frac{1}{2a}

    This activation function is able to better extrapolate to previously unseen
    data, especially in the case of learning periodic functions.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Parameters:
        - a: trainable parameter controlling the frequency of the periodic part

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig,
          Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.

        Args:
            in_features: shape of the input (int or list of ints)
            a: initial value for the trainable parameter `a`; if None, `a` is
               drawn from an Exponential(0.1) distribution (a mix of frequencies)
            trainable: sets `a` as a trainable parameter

        `a` is initialized to 1 by default; higher values = higher frequency.
        5-50 is a good starting point if you already think your data is
        periodic; consider starting lower (e.g. 0.5) if you think not — but
        don't worry, `a` will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        # Normalize `in_features` to a list so it can be used as a shape.
        self.in_features = in_features if isinstance(in_features, list) else [in_features]

        # Initialize `a`: a fixed constant per element, or random frequencies.
        if a is not None:
            initial_a = torch.ones(self.in_features) * a  # create a tensor out of alpha
        else:
            m = Exponential(torch.tensor([0.1]))
            initial_a = m.rsample(self.in_features).squeeze()  # random init = mix of frequencies

        # BUG FIX: the original wrote `self.a.requiresGrad = trainable`, which
        # merely attaches an unused attribute (the real flag is `requires_grad`),
        # so `trainable=False` silently had no effect. Passing the flag to
        # Parameter makes `trainable` actually freeze/unfreeze `a`.
        self.a = Parameter(initial_a, requires_grad=trainable)

    def extra_repr(self) -> str:
        # Shown in repr(module) for easier debugging of layer shapes.
        return 'in_features={}'.format(self.in_features)

    def forward(self, x):
        '''
        Forward pass of the function. Applies the function to the input
        elementwise:

            Snake(x) := x + (1/a) * sin^2(a * x)
        '''
        # NOTE(review): if `a` ever becomes exactly 0 this divides by zero;
        # the random Exponential init makes that vanishingly unlikely, and an
        # explicit `a=0` would be a caller error — left as-is to preserve behavior.
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)