import math
from typing import Any, Callable, Dict, Sequence, Union

import jax
import jax.numpy as jnp
from jax import dtypes, random
from jax.nn.initializers import Initializer

from flax import linen as nn


class FourierEmbs(nn.Module):
    """Random Fourier feature embedding.

    Projects inputs through a Gaussian-initialized kernel (scaled by
    ``embed_scale``) and concatenates the cosine and sine of the projection,
    producing an ``embed_dim``-dimensional embedding.
    """

    embed_scale: float
    embed_dim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # The kernel has embed_dim // 2 columns so that the cos/sin
        # concatenation below yields embed_dim features in total.
        kernel = self.param(
            "kernel",
            jax.nn.initializers.normal(self.embed_scale, dtype=self.dtype),
            (x.shape[-1], self.embed_dim // 2),
        )
        y = jnp.concatenate(
            [jnp.cos(jnp.dot(x, kernel)), jnp.sin(jnp.dot(x, kernel))], axis=-1
        )
        return y
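

# A minimal usage sketch (illustrative, not part of the module's API): embeds
# 2-D points into 64 Fourier features. The shapes, scale, and PRNG seed below
# are arbitrary assumptions for demonstration.
def _demo_fourier_embs():
    embs = FourierEmbs(embed_scale=10.0, embed_dim=64)
    x = jnp.ones((8, 2))
    params = embs.init(jax.random.PRNGKey(0), x)
    return embs.apply(params, x)  # shape (8, 64): 32 cosine + 32 sine features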


def _weight_fact(init_fn, mean, stddev, dtype=jnp.float32):
    """Wrap ``init_fn`` so the kernel is stored in factored form ``(g, v)``.

    The per-output scale ``g`` is sampled log-normally around ``exp(mean)``,
    and the direction ``v`` is the base initialization divided by ``g``, so
    that ``g * v`` reproduces the original initialization.
    """

    def init(key, shape):
        key1, key2 = jax.random.split(key)
        w = init_fn(key1, shape)
        g = mean + nn.initializers.normal(stddev, dtype=dtype)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v

    return init


class Dense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.glorot_normal()
    bias_init: Callable = nn.initializers.zeros
    reparam: Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            # Flax forwards the extra arguments (shape, dtype) to the
            # initializer, which expects (key, shape, dtype). Calling
            # self.kernel_init(dtype=...) directly would fail, since the
            # required key and shape arguments would be missing.
            kernel = self.param(
                "kernel", self.kernel_init, (x.shape[-1], self.features), self.dtype
            )
        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                    dtype=self.dtype,
                ),
                (x.shape[-1], self.features),
            )
            kernel = g * v
        else:
            raise ValueError(f"unsupported reparameterization type: {self.reparam['type']}")

        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        y = jnp.dot(x, kernel) + bias
        return y
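

# A minimal usage sketch (illustrative, not part of the module's API): a Dense
# layer with the weight-factorized kernel. The feature sizes and the
# mean/stddev of the scale distribution are arbitrary assumptions.
def _demo_weight_fact_dense():
    layer = Dense(
        features=32, reparam={"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
    )
    x = jnp.ones((8, 16))
    params = layer.init(jax.random.PRNGKey(0), x)
    # The stored "kernel" is the pair (g, v); the effective kernel is g * v.
    return layer.apply(params, x)  # shape (8, 32)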


def _compute_fans(
    shape: tuple,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Union[int, Sequence[int]] = (),
):
    """Compute effective input and output sizes for a linear or convolutional layer.

    Axes not in ``in_axis``, ``out_axis``, or ``batch_axis`` are assumed to
    constitute the "receptive field" of a convolution (kernel spatial
    dimensions).
    """
    if len(shape) <= 1:
        # Note: `shape` is a tuple, so the original `shape.rank` would have
        # raised an AttributeError; use len(shape) instead.
        raise ValueError(
            f"Can't compute input and output sizes of a {len(shape)}"
            "-dimensional weights tensor. Must be at least 2D."
        )

    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = math.prod([shape[i] for i in in_axis])
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = math.prod([shape[i] for i in out_axis])
    if isinstance(batch_axis, int):
        batch_size = shape[batch_axis]
    else:
        batch_size = math.prod([shape[i] for i in batch_axis])
    receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out
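

# Worked example (illustrative): for a plain (in, out) kernel the fans are
# just the two trailing dimensions; for a 3x3 conv kernel the receptive field
# size (9) multiplies both fans.
def _demo_compute_fans():
    assert _compute_fans((16, 32)) == (16, 32)
    assert _compute_fans((3, 3, 16, 32)) == (9 * 16, 9 * 32)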


def custom_uniform(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    distribution: str = "uniform",
) -> Initializer:
    """Builds an initializer that returns real random arrays scaled by the fan sizes.

    :param numerator: the numerator of the variance-scaling ratio ``numerator / fan``.
    :type numerator: float
    :param mode: which fan to scale by: ``"fan_in"``, ``"fan_out"``, or ``"fan_avg"``.
    :type mode: str
    :param dtype: optional; the initializer's default dtype.
    :type dtype: jnp.dtype
    :param in_axis: the axis or axes that specify the input size.
    :type in_axis: Union[int, Sequence[int]]
    :param out_axis: the axis or axes that specify the output size.
    :type out_axis: Union[int, Sequence[int]]
    :param batch_axis: the axis or axes that specify the batch size.
    :type batch_axis: Sequence[int]
    :param distribution: the distribution to sample from: ``"uniform"``,
        ``"normal"``, or ``"uniform_squared"``.
    :type distribution: str

    :return: An initializer that, for ``"uniform"``, returns arrays whose values
        are uniformly distributed in ``[-sqrt(numerator / fan), sqrt(numerator / fan))``.
    :rtype: Initializer
    """

    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        dtype = dtypes.canonicalize_dtype(dtype)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")
        if distribution == "uniform":
            return random.uniform(
                key,
                shape,
                dtype,
                minval=-jnp.sqrt(numerator / denominator),
                maxval=jnp.sqrt(numerator / denominator),
            )
        elif distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(numerator / denominator)
        elif distribution == "uniform_squared":
            # Same as "uniform" but without the square root on the bound.
            return random.uniform(
                key, shape, dtype, minval=-numerator / denominator, maxval=numerator / denominator
            )
        else:
            raise ValueError(
                f"invalid distribution for variance scaling initializer: {distribution}"
            )

    return init
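

# A minimal usage sketch (illustrative, not part of the module's API): the
# defaults give a He-style uniform bound of sqrt(6 / fan_in). The shape and
# seed below are arbitrary assumptions for demonstration.
def _demo_custom_uniform():
    init = custom_uniform(numerator=6, mode="fan_in")
    # For a (16, 32) kernel, values lie in [-sqrt(6 / 16), sqrt(6 / 16)).
    return init(jax.random.PRNGKey(0), (16, 32))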