# Snake activation module.
# Provenance (from the hosting page): uploaded by sereich,
# commit f113387 "Initial commit of Radio Upscaling UI (minus models)", 2.39 kB.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
from torch.distributions.exponential import Exponential
class Snake(nn.Module):
    r'''
    Implementation of the serpentine-like, sine-based periodic activation function:

    .. math::
        Snake_a := x + \frac{1}{a} \sin^2(ax) = x - \frac{1}{2a}\cos(2ax) + \frac{1}{2a}

    This activation function is able to better extrapolate to previously unseen data,
    especially in the case of learning periodic functions.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions.
          The parameter `a` has shape `in_features` and is combined with the
          input by broadcasting, so `in_features` must be broadcast-compatible
          with the trailing dimensions of the input.
        - Output: (N, *), same shape as the input

    Parameters:
        - a - frequency parameter of shape `in_features` (trainable unless
          `trainable=False` is passed)

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.

        Args:
            in_features: shape of `a`; an int, or a list/tuple of ints.
            a: initial value for every entry of `a`. If None, each entry is
               drawn independently from an Exponential distribution with
               rate 0.1 (mean 10), giving a random mix of frequencies.
            trainable: if True (default), `a` receives gradients and is
               trained along with the rest of the model; if False it is frozen.

        Higher values of `a` mean higher frequency. 5-50 is a good starting
        point if you already think your data is periodic; consider starting
        lower, e.g. 0.5, if you think not.
        '''
        super().__init__()
        # Normalise to a fresh list: accept int, list, tuple or torch.Size, and
        # avoid aliasing a list owned by the caller.
        if isinstance(in_features, (list, tuple)):
            self.in_features = list(in_features)
        else:
            self.in_features = [in_features]

        if a is not None:
            initial = torch.ones(self.in_features) * a  # constant fill with `a`
        else:
            m = Exponential(torch.tensor([0.1]))
            # rsample(shape) returns shape + [1] (the rate tensor is 1-element);
            # squeeze() removes that trailing dim. It also drops any size-1
            # entries of in_features. This is left as-is so parameter shapes stay
            # compatible with existing state_dicts.
            initial = m.rsample(self.in_features).squeeze()

        # The original code set `self.a.requiresGrad`, which is not a PyTorch
        # attribute, so `trainable=False` had no effect. `requires_grad` is
        # the real flag.
        self.a = Parameter(initial, requires_grad=trainable)

    def extra_repr(self) -> str:
        return 'in_features={}'.format(self.in_features)

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise (with `a` broadcast
        against `x`):
            Snake := x + 1/a * sin^2(x * a)
        Note: there is no epsilon guard, so any entry of `a` equal to 0
        yields inf/nan in the output.
        '''
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)