import math

import torch


def parallel_stabilized_simple(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    lower_triangular_matrix: torch.Tensor = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
    **kwargs,
) -> torch.Tensor:
    """This is the mLSTM cell in parallel form.

    This version is stabilized. We control the range of exp() arguments by
    ensuring that they are always smaller than 0.0 by subtracting the maximum.

    Args:
        queries (torch.Tensor): (B, NH, S, DH)
        keys (torch.Tensor): (B, NH, S, DH)
        values (torch.Tensor): (B, NH, S, DH)
        igate_preact (torch.Tensor): (B, NH, S, 1)
        fgate_preact (torch.Tensor): (B, NH, S, 1)
        lower_triangular_matrix (torch.Tensor, optional): (S, S). Defaults to None.
        stabilize_rowwise (bool, optional): Whether to stabilize the combination matrix C rowwise
            (take the maximum per row). Alternative: subtract the maximum over all rows. Defaults to True.
        eps (float, optional): Small constant added to the normalizer. Defaults to 1e-6.

    Returns:
        torch.Tensor: (B, NH, S, DH), h_tilde_state
    """

    B, NH, S, DH = queries.shape
    _dtype, _device = queries.dtype, queries.device

    # forget gate pre-activations go through log(sigmoid(.)) so that gate products
    # can be accumulated as sums in log space
    log_fgates = torch.nn.functional.logsigmoid(fgate_preact)  # (B, NH, S, 1)

    # causal mask: build a fresh (S, S) mask if none is given or the provided one is larger than needed
    if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
        ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
    else:
        ltr = lower_triangular_matrix
    assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"

    # cumulative sum of the log forget gates; a zero row is prepended for the initial state
    log_fgates_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fgates, dim=-2),
        ],
        dim=-2,
    )  # (B, NH, S+1, 1)

    # pairwise differences of the cumulative sums give the log of the forget gate
    # products between any two time steps
    rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1)  # (B, NH, S+1, S+1)
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1)

    # mask out future positions with -inf; drop the first row/column belonging to the initial state
    log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf"))  # (B, NH, S, S)

    # gate decay matrix D: forget gate products combined with the input gate pre-activations
    log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1)  # (B, NH, S, S)

    # stabilization: subtract the maximum so that all exp() arguments are <= 0
    if stabilize_rowwise:
        max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)  # (B, NH, S, 1)
    else:
        max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)  # (B, NH, 1, 1)

    log_D_matrix_stabilized = log_D_matrix - max_log_D
    D_matrix = torch.exp(log_D_matrix_stabilized)  # (B, NH, S, S)

    keys_scaled = keys / math.sqrt(DH)

    # combination matrix C: query-key scores weighted by the gate decay matrix
    qk_matrix = queries @ keys_scaled.transpose(-2, -1)  # (B, NH, S, S)
    C_matrix = qk_matrix * D_matrix
    # normalizer: row sums of C, lower-bounded by exp(-max_log_D) to avoid division by ~0
    normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D))  # (B, NH, S, 1)
    C_matrix_normalized = C_matrix / (normalizer + eps)

    # weighted sum of the values
    h_tilde_state = C_matrix_normalized @ values  # (B, NH, S, DH)

    return h_tilde_state
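

# A minimal usage sketch of the parallel form. The shapes follow the docstring above;
# the concrete sizes and the helper name are illustrative assumptions, not part of the module API.
def _example_parallel_usage():
    B, NH, S, DH = 2, 4, 16, 32  # example values only
    q = torch.randn(B, NH, S, DH)
    k = torch.randn(B, NH, S, DH)
    v = torch.randn(B, NH, S, DH)
    igate_preact = torch.randn(B, NH, S, 1)
    fgate_preact = torch.randn(B, NH, S, 1)
    h_tilde = parallel_stabilized_simple(q, k, v, igate_preact, fgate_preact)
    assert h_tilde.shape == (B, NH, S, DH)
    return h_tilde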


def recurrent_step_stabilized_simple(
    c_state: torch.Tensor,
    n_state: torch.Tensor,
    m_state: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    eps: float = 1e-6,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """This is a single step of the mLSTM operation in recurrent form.

    Args:
        c_state (torch.Tensor): (B, NH, DH, DH)
        n_state (torch.Tensor): (B, NH, DH, 1)
        m_state (torch.Tensor): (B, NH, 1, 1)
        q (torch.Tensor): (B, NH, 1, DH)
        k (torch.Tensor): (B, NH, 1, DH)
        v (torch.Tensor): (B, NH, 1, DH)
        igate_preact (torch.Tensor): (B, NH, 1, 1)
        fgate_preact (torch.Tensor): (B, NH, 1, 1)
        eps (float, optional): Small constant added to the denominator. Defaults to 1e-6.

    Returns:
        tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            (hidden_state (B, NH, 1, DH),
            (c_state_new (B, NH, DH, DH), n_state_new (B, NH, DH, 1), m_state_new (B, NH, 1, 1)))
    """
    B, NH, S, DH = q.shape  # S is 1 for a single step

    # reshape (B, NH, 1, DH) -> (B, NH, DH, 1); use the out-of-place squeeze so the
    # caller's tensors are not modified
    q, k, v = q.squeeze(2).unsqueeze(-1), k.squeeze(2).unsqueeze(-1), v.squeeze(2).unsqueeze(-1)

    # gates
    log_fg_act = torch.nn.functional.logsigmoid(fgate_preact)  # (B, NH, 1, 1)

    # update the stabilizer state (running maximum in log space)
    m_state_new = torch.max(log_fg_act + m_state, igate_preact)  # (B, NH, 1, 1)

    fg_act = torch.exp(log_fg_act + m_state - m_state_new)  # (B, NH, 1, 1)
    ig_act = torch.exp(igate_preact - m_state_new)  # (B, NH, 1, 1)

    k_scaled = k / math.sqrt(DH)

    # update the cell (matrix) state and the normalizer state
    c_state_new = fg_act * c_state + ig_act * (k_scaled @ v.transpose(-1, -2))  # (B, NH, DH, DH)
    n_state_new = fg_act * n_state + ig_act * k_scaled  # (B, NH, DH, 1)

    h_num = q.transpose(-1, -2) @ c_state_new  # (B, NH, 1, DH)

    # denominator: |q . n|, lower-bounded by exp(-m) to avoid division by ~0
    qn_dotproduct = q.transpose(-1, -2) @ n_state_new  # (B, NH, 1, 1)
    max_val = torch.exp(-m_state_new)  # (B, NH, 1, 1)
    h_denom = torch.maximum(qn_dotproduct.abs(), max_val) + eps
    h = h_num / h_denom  # (B, NH, 1, DH)

    return h, (c_state_new, n_state_new, m_state_new)
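

# A small consistency sketch: unroll the recurrent step over a sequence and compare it
# with the parallel form. Initial states (zeros), example sizes, and the helper name
# are illustrative assumptions.
def _example_recurrent_matches_parallel():
    B, NH, S, DH = 2, 4, 16, 32  # example values only
    q = torch.randn(B, NH, S, DH)
    k = torch.randn(B, NH, S, DH)
    v = torch.randn(B, NH, S, DH)
    igate_preact = torch.randn(B, NH, S, 1)
    fgate_preact = torch.randn(B, NH, S, 1)

    h_parallel = parallel_stabilized_simple(q, k, v, igate_preact, fgate_preact)

    c_state = torch.zeros(B, NH, DH, DH)
    n_state = torch.zeros(B, NH, DH, 1)
    m_state = torch.zeros(B, NH, 1, 1)
    h_steps = []
    for t in range(S):
        h_t, (c_state, n_state, m_state) = recurrent_step_stabilized_simple(
            c_state,
            n_state,
            m_state,
            q[:, :, t : t + 1, :],
            k[:, :, t : t + 1, :],
            v[:, :, t : t + 1, :],
            igate_preact[:, :, t : t + 1, :],
            fgate_preact[:, :, t : t + 1, :],
        )
        h_steps.append(h_t)  # (B, NH, 1, DH)
    h_recurrent = torch.cat(h_steps, dim=2)

    # the two forms should match up to numerical error; report the largest deviation
    return (h_parallel - h_recurrent).abs().max()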