|
import jax.numpy as jnp |
|
from flax import linen as nn |
|
from typing import Any |
|
|
|
import jax |
|
from .utils import custom_uniform |
|
from jax.nn.initializers import Initializer |
|
|
|
|
|
def complex_kernel_uniform_init(numerator : float = 6,
                                mode : str = "fan_in",
                                dtype : jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Build an initializer for complex-valued kernels.

    The real and imaginary parts are drawn independently from the
    ``custom_uniform`` initializer configured with ``numerator``, ``mode`` and
    ``distribution``. They are then combined as ``real + 1j * imag``.

    Args:
        numerator: Forwarded to ``custom_uniform``.
        mode: Forwarded to ``custom_uniform`` (e.g. ``"fan_in"``).
        dtype: Default ``dtype`` of the returned initializer. It is used when
            the caller does not pass one explicitly.
        distribution: Forwarded to ``custom_uniform``.

    Returns:
        An initializer ``init(key, shape, dtype=dtype)`` that returns a
        complex array of the requested ``shape``.
    """
    part_init = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)

    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        # Independent keys for the two parts. With one shared key the
        # imaginary part would be an exact copy of the real part, so every
        # weight would end up with the same phase (pi/4).
        key_real, key_imag = jax.random.split(key)
        # Sample each part in the real dtype underlying `dtype`:
        # float32 for complex64, float64 for complex128, unchanged for real
        # dtypes. Callers such as nn.Dense may pass a complex param_dtype.
        # NOTE(review): assumes custom_uniform expects a real floating dtype.
        part_dtype = jnp.finfo(dtype).dtype
        real_kernel = part_init(key_real, shape, part_dtype)
        imag_kernel = part_init(key_imag, shape, part_dtype)
        return real_kernel + 1j * imag_kernel

    return init
|
|
|
|
|
class WIRE(nn.Module):
    """Stack of Gabor layers followed by a dense read-out.

    The network applies ``num_layers + 1`` Gabor layers of width
    ``hidden_dim``. These are ``ComplexGaborLayer`` when ``complexgabor`` is
    set, otherwise ``RealGaborLayer``. The first layer uses ``first_omega_0``
    and the rest use ``hidden_omega_0``. A dense layer with ``output_dim``
    features follows, and its real part is returned.
    """

    # Number of features produced by the final dense layer.
    output_dim: int
    # Width of every Gabor layer.
    hidden_dim: int
    # Number of Gabor layers after the first one (total is num_layers + 1).
    num_layers: int
    # omega_0 for every Gabor layer except the first.
    hidden_omega_0: float
    # omega_0 for the first Gabor layer.
    first_omega_0: float
    # s_0 shared by all Gabor layers.
    scale: float
    # Selects ComplexGaborLayer (built with complex64) instead of RealGaborLayer.
    complexgabor: bool = False
    # dtype for the real Gabor layers and the output layer's params.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        layer_cls, layer_dtype = (
            (ComplexGaborLayer, jnp.complex64)
            if self.complexgabor
            else (RealGaborLayer, self.dtype)
        )

        # Index 0 is the "first" layer; all later ones are hidden layers.
        self.kernel_net = [
            layer_cls(
                output_dim=self.hidden_dim,
                omega_0=self.first_omega_0 if idx == 0 else self.hidden_omega_0,
                s_0=self.scale,
                is_first_layer=(idx == 0),
                dtype=layer_dtype,
            )
            for idx in range(self.num_layers + 1)
        ]

        # NOTE(review): with complexgabor=True the hidden features are
        # complex, but this layer's params use self.dtype (float32 by
        # default). Confirm that a real-valued read-out is intended.
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        hidden = x
        for gabor in self.kernel_net:
            hidden = gabor(hidden)
        return jnp.real(self.output_linear(hidden))
|
|
|
|
|
class ComplexGaborLayer(nn.Module):
    """Complex Gabor layer computing ``exp(1j * omega_0 * z - |s_0 * z|**2)``.

    ``z = linear(x)`` is the output of a single dense layer. The same
    pre-activation drives both the oscillation and the Gaussian envelope.
    """

    # Number of output features of the dense layer.
    output_dim: int
    # Multiplier applied to the pre-activation inside the oscillation term.
    omega_0: float
    # Multiplier applied to the pre-activation inside the envelope term.
    s_0: float
    # Selects first-layer init settings and keeps self.dtype as param dtype.
    is_first_layer: bool = False
    # Param dtype for the first layer; later layers always use complex64.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Init settings passed to the complex kernel initializer.
        # First layer: numerator 1 with "uniform_squared".
        # Later layers: numerator 6 / omega_0**2 with "uniform".
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        param_dtype = self.dtype if self.is_first_layer else jnp.complex64

        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
            param_dtype=param_dtype,
        )

    def __call__(self, x):
        # Evaluate the shared dense layer once. Both terms use the same
        # pre-activation, so a second matmul would be redundant.
        z = self.linear(x)
        omega = self.omega_0 * z
        scale = self.s_0 * z
        return jnp.exp(1j * omega - (jnp.abs(scale)**2))
|
|
|
|
|
class RealGaborLayer(nn.Module):
    """Real Gabor layer computing ``cos(omega_0 * f(x)) * exp(-(s_0 * g(x))**2)``.

    ``f`` (``freqs``) and ``g`` (``scales``) are two separate dense layers.
    Both use the same initializer settings and parameter dtype.
    """

    # Number of output features of both dense layers.
    output_dim: int
    # Multiplier applied to the frequency branch before cos.
    omega_0: float
    # Multiplier applied to the scale branch before the Gaussian envelope.
    s_0: float
    # Selects first-layer init settings.
    is_first_layer: bool = False
    # Param dtype of both dense layers (also forwarded to custom_uniform).
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        first = self.is_first_layer
        numerator = 1 if first else 6 / self.omega_0**2
        distribution = "uniform_squared" if first else "uniform"

        def make_dense():
            # Each call builds a fresh Dense with its own parameters.
            return nn.Dense(
                features=self.output_dim,
                use_bias=True,
                param_dtype=self.dtype,
                kernel_init=custom_uniform(
                    numerator=numerator,
                    mode="fan_in",
                    distribution=distribution,
                    dtype=self.dtype,
                ),
            )

        self.freqs = make_dense()
        self.scales = make_dense()

    def __call__(self, x):
        phase = self.omega_0 * self.freqs(x)
        envelope = self.s_0 * self.scales(x)
        return jnp.cos(phase) * jnp.exp(-(envelope**2))
|
|