import jax.numpy as jnp
from flax import linen as nn
from typing import Any
import jax
from .utils import custom_uniform
from jax.nn.initializers import Initializer
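# Flax implementation of WIRE (Wavelet Implicit Neural Representations,
# Saragadam et al., 2023): an MLP whose nonlinearities are real or complex
# Gabor wavelets instead of sines or ReLUs.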
def complex_kernel_uniform_init(numerator: float = 6,
                                mode: str = "fan_in",
                                dtype: jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Initializer for complex kernels: real and imaginary parts are drawn independently."""
    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        # Sample with the real counterpart of the requested dtype, since uniform
        # sampling in JAX does not accept complex dtypes.
        real_dtype = jnp.finfo(dtype).dtype if jnp.issubdtype(dtype, jnp.complexfloating) else dtype
        # Split the key so the real and imaginary parts are uncorrelated.
        key_real, key_imag = jax.random.split(key)
        real_kernel = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)(key_real, shape, real_dtype)
        imag_kernel = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)(key_imag, shape, real_dtype)
        return real_kernel + 1j * imag_kernel
    return init
class WIRE(nn.Module):
output_dim: int
hidden_dim: int
num_layers: int
hidden_omega_0: float
first_omega_0: float
scale: float
complexgabor: bool = False
dtype: jnp.dtype = jnp.float32
    def setup(self):
        # Pick the Gabor nonlinearity. ComplexGaborLayer promotes its hidden
        # layers to complex64 internally, so the real module dtype is forwarded
        # unchanged and the first layer keeps it for the coordinate inputs.
        WIRElayer = ComplexGaborLayer if self.complexgabor else RealGaborLayer
        dtype = self.dtype
self.kernel_net = [
WIRElayer(
output_dim=self.hidden_dim,
omega_0=self.first_omega_0,
s_0=self.scale,
is_first_layer=True,
dtype=dtype
)
] + [
WIRElayer(
output_dim=self.hidden_dim,
omega_0=self.hidden_omega_0,
s_0=self.scale,
is_first_layer=False,
dtype=dtype
)
for _ in range(self.num_layers)
]
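        # Final linear projection to output_dim; __call__ keeps only its real part.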
self.output_linear = nn.Dense(
features=self.output_dim,
use_bias=True,
kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
param_dtype=self.dtype,
)
    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)
        # Activations may be complex; only the real part of the projection is returned.
        out = jnp.real(self.output_linear(x))
        return out
class ComplexGaborLayer(nn.Module):
output_dim: int
omega_0: float
s_0: float
is_first_layer: bool = False
dtype: jnp.dtype = jnp.float32
    def setup(self):
        # Init bounds follow the SIREN/WIRE scheme: the first layer uses numerator 1
        # with the "uniform_squared" distribution, hidden layers use 6 / omega_0**2
        # with a plain uniform range.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        # The first layer keeps the (real) module dtype for coordinate inputs;
        # subsequent layers hold complex64 parameters.
        dtype = self.dtype if self.is_first_layer else jnp.complex64
self.linear = nn.Dense(
features=self.output_dim,
use_bias=True,
kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
param_dtype=dtype
)
    def __call__(self, x):
        # A single linear projection provides both the carrier and the envelope terms.
        lin = self.linear(x)
        omega = self.omega_0 * lin
        scale = self.s_0 * lin
        # Complex Gabor wavelet: exp(i * omega) modulated by a Gaussian envelope.
        return jnp.exp(1j * omega - jnp.abs(scale)**2)
class RealGaborLayer(nn.Module):
output_dim: int
omega_0: float
s_0: float
is_first_layer: bool = False
dtype: jnp.dtype = jnp.float32
    def setup(self):
        # Same init scheme as the complex layer: numerator 1 with "uniform_squared"
        # for the first layer, 6 / omega_0**2 with a plain uniform range afterwards.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        # Separate linear projections for the frequency (carrier) and scale (envelope) terms.
        self.freqs = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )
        self.scales = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )
    def __call__(self, x):
        omega = self.omega_0 * self.freqs(x)
        scale = self.s_0 * self.scales(x)
        # Real Gabor wavelet: cosine carrier modulated by a Gaussian envelope.
        return jnp.cos(omega) * jnp.exp(-(scale**2))
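if __name__ == "__main__":
    # Minimal smoke test / usage sketch, not part of the library API. The
    # hyperparameters below are illustrative assumptions; run as a module
    # (python -m ...) so the relative import of .utils resolves.
    model = WIRE(output_dim=3, hidden_dim=64, num_layers=2,
                 first_omega_0=30.0, hidden_omega_0=30.0, scale=10.0,
                 complexgabor=True)
    coords = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 2)   # 16 query points in 2-D
    params = model.init(jax.random.PRNGKey(0), coords)    # initialise all Dense kernels
    out = model.apply(params, coords)                      # shape (16, 3), real-valued
    print(out.shape, out.dtype)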