zaydzuhri committed on
Commit 5379428 · verified · 1 Parent(s): 652030e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. config.json +33 -0
  2. fla/layers/__init__.py +44 -0
  3. fla/layers/attn.py +243 -0
  4. fla/layers/bitattn.py +192 -0
  5. fla/layers/gated_deltaproduct.py +351 -0
  6. fla/layers/hgrn.py +168 -0
  7. fla/layers/hgrn2.py +211 -0
  8. fla/layers/lightnet.py +210 -0
  9. fla/layers/linear_attn.py +166 -0
  10. fla/layers/multiscale_retention.py +298 -0
  11. fla/layers/nsa.py +138 -0
  12. fla/layers/rwkv6.py +307 -0
  13. fla/layers/simple_gla.py +261 -0
  14. fla/models/__init__.py +51 -0
  15. fla/models/utils.py +147 -0
  16. fla/modules/activations.py +471 -0
  17. fla/modules/fused_cross_entropy.py +419 -0
  18. fla/modules/fused_norm_gate.py +995 -0
  19. fla/modules/grpo.py +396 -0
  20. fla/modules/l2norm.py +176 -0
  21. fla/modules/layernorm_gated.py +528 -0
  22. fla/modules/mlp.py +127 -0
  23. fla/ops/attn/__init__.py +17 -0
  24. fla/ops/attn/__pycache__/parallel_softpick.cpython-312.pyc +0 -0
  25. fla/ops/attn/naive_softpick.py +39 -0
  26. fla/ops/based/__init__.py +9 -0
  27. fla/ops/common/__init__.py +1 -0
  28. fla/ops/common/chunk_delta_h.py +399 -0
  29. fla/ops/common/chunk_h.py +422 -0
  30. fla/ops/common/chunk_h_split.py +677 -0
  31. fla/ops/common/chunk_o.py +668 -0
  32. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  33. fla/ops/common/utils.py +69 -0
  34. fla/ops/delta_rule/README.md +90 -0
  35. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  36. fla/ops/delta_rule/chunk.py +373 -0
  37. fla/ops/delta_rule/fused_recurrent.py +607 -0
  38. fla/ops/forgetting_attn/__init__.py +7 -0
  39. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  40. fla/ops/gated_delta_rule/chunk.py +392 -0
  41. fla/ops/generalized_delta_rule/README.md +37 -0
  42. fla/ops/generalized_delta_rule/__init__.py +9 -0
  43. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc +0 -0
  44. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  45. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc +0 -0
  46. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  47. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  48. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  49. fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +184 -0
  50. fla/ops/gla/fused_recurrent.py +113 -0
config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "architectures": [
3
+ "TransformerForCausalLM"
4
+ ],
5
+ "attn_impl": "parallel_attn",
6
+ "bos_token_id": 1,
7
+ "elementwise_affine": true,
8
+ "eos_token_id": 2,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "fuse_swiglu": true,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 2048,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "max_position_embeddings": 8192,
18
+ "model_type": "transformer",
19
+ "norm_eps": 1e-06,
20
+ "num_heads": 32,
21
+ "num_hidden_layers": 32,
22
+ "num_kv_heads": null,
23
+ "pad_token_id": 2,
24
+ "qk_norm": false,
25
+ "qkv_bias": false,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.51.3",
30
+ "use_cache": true,
31
+ "vocab_size": 32000,
32
+ "window_size": null
33
+ }
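The config above describes the vanilla softmax-attention baseline of the flash-linear-attention (fla) family: 32 layers, hidden size 2048, 32 heads, 8K positions, with attn_impl set to "parallel_attn". Below is a minimal loading sketch (an illustration, not part of the committed files). It assumes the fla package is installed and that importing it registers the custom "transformer" model type with the Hugging Face Auto classes; the repository id is a placeholder.

import fla  # noqa: F401  # assumed side effect: registers TransformerConfig / TransformerForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("zaydzuhri/some-repo")   # placeholder id; reads the config.json above
print(config.model_type, config.num_hidden_layers, config.attn_impl)  # transformer 32 parallel_attn
model = AutoModelForCausalLM.from_config(config)             # random-init model with this architecture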
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from .abc import ABCAttention
5
+ from .attn import Attention
6
+ from .based import BasedLinearAttention
7
+ from .bitattn import BitAttention
8
+ from .delta_net import DeltaNet
9
+ from .forgetting_attn import ForgettingAttention
10
+ from .gated_deltanet import GatedDeltaNet
11
+ from .gated_deltaproduct import GatedDeltaProduct
12
+ from .gla import GatedLinearAttention
13
+ from .gsa import GatedSlotAttention
14
+ from .hgrn import HGRNAttention
15
+ from .hgrn2 import HGRN2Attention
16
+ from .lightnet import LightNetAttention
17
+ from .linear_attn import LinearAttention
18
+ from .multiscale_retention import MultiScaleRetention
19
+ from .nsa import NativeSparseAttention
20
+ from .rebased import ReBasedLinearAttention
21
+ from .rwkv6 import RWKV6Attention
22
+ from .rwkv7 import RWKV7Attention
23
+
24
+ __all__ = [
25
+ 'ABCAttention',
26
+ 'Attention',
27
+ 'BasedLinearAttention',
28
+ 'BitAttention',
29
+ 'DeltaNet',
30
+ 'ForgettingAttention',
31
+ 'GatedDeltaNet',
32
+ 'GatedDeltaProduct',
33
+ 'GatedLinearAttention',
34
+ 'GatedSlotAttention',
35
+ 'HGRNAttention',
36
+ 'HGRN2Attention',
37
+ 'LightNetAttention',
38
+ 'LinearAttention',
39
+ 'MultiScaleRetention',
40
+ 'NativeSparseAttention',
41
+ 'ReBasedLinearAttention',
42
+ 'RWKV6Attention',
43
+ 'RWKV7Attention',
44
+ ]
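For quick experimentation the exported classes can also be constructed directly. A hedged sketch follows, assuming the fla package and its Triton dependencies are available; several fla kernels expect a CUDA device, so treat this as a shape sketch rather than a guaranteed CPU run.

import torch
from fla.layers import Attention

# Plain softmax attention layer exported above; "naive_attn" selects the
# reference op instead of flash-attn or the Triton kernels.
layer = Attention(hidden_size=256, num_heads=4, attn_impl="naive_attn", layer_idx=0)
x = torch.randn(2, 16, 256)
o, attentions, cache = layer(x)   # forward returns (output, attentions, past_key_values)
print(o.shape)                    # torch.Size([2, 16, 256])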
fla/layers/attn.py ADDED
@@ -0,0 +1,243 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+ from fla.ops import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class Attention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ qkv_bias: bool = False,
43
+ qk_norm: bool = False,
44
+ window_size: Optional[int] = None,
45
+ rope_theta: Optional[float] = 10000.,
46
+ max_position_embeddings: Optional[int] = None,
47
+ layer_idx: int = None,
48
+ attn_impl: str = "flash_attn",
49
+ ):
50
+ super().__init__()
51
+
52
+ self.hidden_size = hidden_size
53
+ self.num_heads = num_heads
54
+ if num_kv_heads is None:
55
+ self.num_kv_heads = self.num_heads
56
+ else:
57
+ self.num_kv_heads = num_kv_heads
58
+ self.num_kv_groups = num_heads // self.num_kv_heads
59
+ self.head_dim = self.hidden_size // self.num_heads
60
+ self.kv_dim = self.num_kv_heads * self.head_dim
61
+ self.qkv_bias = qkv_bias
62
+ self.qk_norm = qk_norm
63
+
64
+ self.window_size = window_size
65
+ self.rope_theta = rope_theta
66
+ self.max_position_embeddings = max_position_embeddings
67
+ self.layer_idx = layer_idx
68
+ self.attn_impl = attn_impl
69
+
70
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
71
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
72
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
73
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
74
+
75
+ if "scaled" in self.attn_impl:
76
+ self.s = nn.Parameter(torch.empty(self.num_heads, 1))
77
+ self.register_buffer("logn", torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
78
+
79
+ if qk_norm:
80
+ self.q_norm = RMSNorm(self.head_dim)
81
+ self.k_norm = RMSNorm(self.head_dim)
82
+
83
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
84
+
85
+ def reset_parameters(self):
86
+ if "scaled" in self.attn_impl:
87
+ nn.init.constant_(self.s, 0.3)
88
+ self.logn.copy_(torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
89
+
90
+ def forward(
91
+ self,
92
+ hidden_states: torch.Tensor,
93
+ attention_mask: Optional[torch.LongTensor] = None,
94
+ past_key_values: Optional[Cache] = None,
95
+ output_attentions: bool = False,
96
+ use_cache: bool = False,
97
+ **kwargs,
98
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
99
+ if attention_mask is not None:
100
+ assert len(attention_mask.shape) == 2, (
101
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
102
+ "for padding purposes (0 indicating padding). "
103
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
104
+ )
105
+
106
+ batch_size, q_len, _ = hidden_states.size()
107
+
108
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
109
+
110
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
111
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
112
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
113
+
114
+ if self.qk_norm:
115
+ q, k = self.q_norm(q), self.k_norm(k)
116
+
117
+ # equivalent to cu_seqlens in `flash_attn`
118
+ cu_seqlens = kwargs.get('cu_seqlens', None)
119
+
120
+ seqlen_offset, max_seqlen = 0, q_len
121
+ if past_key_values is not None:
122
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
123
+ max_seqlen = q.shape[1] + seqlen_offset
124
+
125
+ if attention_mask is not None:
126
+ # adjust the offsets to account for padding tokens
127
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
128
+ max_seqlen = q.shape[1] + max(seqlen_offset)
129
+
130
+ if self.max_position_embeddings is not None:
131
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
132
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
133
+
134
+ if past_key_values is not None:
135
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
136
+ k_cached, v_cached = past_key_values.update(
137
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
138
+ layer_idx=self.layer_idx,
139
+ offset=q_len,
140
+ cache_kwargs=dict(window_size=self.window_size)
141
+ )['attn_state']
142
+ if cache_has_content:
143
+ k, v = k_cached, v_cached
144
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
145
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
146
+
147
+ # if flash_attn_func is None:
148
+ # raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
149
+
150
+ if "scaled" in self.attn_impl:
151
+ k_len = k.shape[1]
152
+ q = q * self.s.to(q.dtype) * self.logn[k_len-q_len:k_len].to(q.dtype)
153
+
154
+ # Contains at least one padding token in the sequence
155
+ if self.attn_impl == "flash_attn":
156
+ if attention_mask is not None:
157
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
158
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
159
+ max_seqlen_q, max_seqlen_k = max_seq_lens
160
+ o = flash_attn_varlen_func(
161
+ q, k, v,
162
+ cu_seqlens_q=cu_seqlens_q,
163
+ cu_seqlens_k=cu_seqlens_k,
164
+ max_seqlen_q=max_seqlen_q,
165
+ max_seqlen_k=max_seqlen_k,
166
+ causal=True,
167
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
168
+ )
169
+ o = pad_input(o, indices_q, batch_size, q_len)
170
+ elif cu_seqlens is not None:
171
+ o = flash_attn_varlen_func(
172
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
173
+ cu_seqlens_q=cu_seqlens,
174
+ cu_seqlens_k=cu_seqlens,
175
+ max_seqlen_q=max_seqlen,
176
+ max_seqlen_k=max_seqlen,
177
+ causal=True,
178
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
179
+ ).unsqueeze(0)
180
+ else:
181
+ o = flash_attn_func(
182
+ q, k, v,
183
+ causal=True,
184
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
185
+ )
186
+ elif self.attn_impl == "parallel_attn":
187
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
188
+ elif self.attn_impl == "parallel_scaled_attn":
189
+ o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
190
+ elif self.attn_impl == "parallel_rectified_attn":
191
+ o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
192
+ elif self.attn_impl == "parallel_softpick_attn":
193
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
194
+ elif self.attn_impl == "parallel_scaled_softpick_attn":
195
+ o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
196
+ elif self.attn_impl == "naive_attn":
197
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
198
+ elif self.attn_impl == "naive_scaled_attn":
199
+ o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
200
+ elif self.attn_impl == "naive_rectified_attn":
201
+ o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
202
+ elif self.attn_impl == "naive_softpick_attn":
203
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
204
+ elif self.attn_impl == "naive_scaled_softpick_attn":
205
+ o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
206
+ else:
207
+ raise ValueError(f"Unknown attention implementation: {self.attn_impl}")
208
+
209
+ o = o.reshape(batch_size, q_len, -1)
210
+ o = self.o_proj(o)
211
+
212
+ if not output_attentions or "parallel" in self.attn_impl or "flash" in self.attn_impl:
213
+ attentions = None
214
+
215
+ return o, attentions, past_key_values
216
+
217
+ def _upad_input(self, q, k, v, attention_mask, q_len):
218
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
219
+ cache_mask = attention_mask[:, -seq_len:]
220
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
221
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
222
+ max_seqlen_k = seqlens.max().item()
223
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
224
+
225
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
226
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
227
+ if q_len == seq_len:
228
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
229
+ cu_seqlens_q = cu_seqlens_k
230
+ max_seqlen_q = max_seqlen_k
231
+ indices_q = indices_k
232
+ elif q_len == 1:
233
+ max_seqlen_q = 1
234
+ # There is a memcpy here, that is very bad.
235
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
236
+ indices_q = cu_seqlens_q[:-1]
237
+ q = q.squeeze(1)
238
+ else:
239
+ # The -q_len: slice assumes left padding.
240
+ attention_mask = attention_mask[:, -q_len:]
241
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
242
+
243
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
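The "scaled" attn_impl variants above rescale the queries with a learned per-head gain self.s and a log-length factor taken from the self.logn buffer before the attention kernel runs. Here is a standalone numeric sketch of that rescaling (an illustration, not part of the commit); shapes and the 0.3 init follow the code above, and all names are local to the example.

import torch

B, T, H, D = 2, 8, 4, 16                      # batch, seq_len, heads, head_dim
max_position_embeddings = 32
q = torch.randn(B, T, H, D)
s = torch.full((H, 1), 0.3)                   # per-head gain, init value from reset_parameters
logn = torch.log(torch.arange(2, max_position_embeddings * 4 + 2, dtype=q.dtype))[:, None, None]
k_len = q_len = T                             # no KV cache in this sketch
# logn slice has shape (T, 1, 1) and s has shape (H, 1); both broadcast against q's (B, T, H, D)
q_scaled = q * s * logn[k_len - q_len:k_len]
print(q_scaled.shape)                         # torch.Size([2, 8, 4, 16])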
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class BitAttention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ window_size: Optional[int] = None,
43
+ rope_theta: Optional[float] = 10000.,
44
+ max_position_embeddings: Optional[int] = None,
45
+ norm_eps: float = 1e-5,
46
+ layer_idx: int = None
47
+ ):
48
+ super().__init__()
49
+
50
+ self.num_heads = num_heads
51
+ if num_kv_heads is None:
52
+ self.num_kv_heads = self.num_heads
53
+ else:
54
+ self.num_kv_heads = num_kv_heads
55
+ self.num_kv_groups = num_heads // self.num_kv_heads
56
+ self.hidden_size = hidden_size
57
+ self.head_dim = self.hidden_size // self.num_heads
58
+ self.kv_dim = self.num_kv_heads * self.head_dim
59
+ self.kv_dim = self.num_kv_heads * self.head_dim
60
+ self.window_size = window_size
61
+ self.rope_theta = rope_theta
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.layer_idx = layer_idx
64
+
65
+ self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
66
+ self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
67
+ self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
68
+ self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
69
+
70
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Cache] = None,
77
+ output_attentions: bool = False,
78
+ use_cache: bool = False,
79
+ **kwargs,
80
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81
+ if attention_mask is not None:
82
+ assert len(attention_mask.shape) == 2, (
83
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
84
+ "for padding purposes (0 indicating padding). "
85
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
86
+ )
87
+
88
+ batch_size, q_len, _ = hidden_states.size()
89
+
90
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
91
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
92
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
93
+
94
+ # equivalent to cu_seqlens in `flash_attn`
95
+ cu_seqlens = kwargs.get('cu_seqlens', None)
96
+
97
+ seqlen_offset, max_seqlen = 0, q_len
98
+ if past_key_values is not None:
99
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
100
+ max_seqlen = q.shape[1] + seqlen_offset
101
+
102
+ if attention_mask is not None:
103
+ # to deliminate the offsets of padding tokens
104
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
105
+ max_seqlen = q.shape[1] + max(seqlen_offset)
106
+
107
+ if self.max_position_embeddings is not None:
108
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
109
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
110
+
111
+ if past_key_values is not None:
112
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
113
+ k_cached, v_cached = past_key_values.update(
114
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
115
+ layer_idx=self.layer_idx,
116
+ offset=q_len,
117
+ cache_kwargs=dict(window_size=self.window_size)
118
+ )['attn_state']
119
+ if cache_has_content:
120
+ k, v = k_cached, v_cached
121
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
122
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
123
+
124
+ if flash_attn_func is None:
125
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
126
+
127
+ # Contains at least one padding token in the sequence
128
+ if attention_mask is not None:
129
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
130
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
131
+ max_seqlen_q, max_seqlen_k = max_seq_lens
132
+ o = flash_attn_varlen_func(
133
+ q, k, v,
134
+ cu_seqlens_q=cu_seqlens_q,
135
+ cu_seqlens_k=cu_seqlens_k,
136
+ max_seqlen_q=max_seqlen_q,
137
+ max_seqlen_k=max_seqlen_k,
138
+ causal=True,
139
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
140
+ )
141
+ o = pad_input(o, indices_q, batch_size, q_len)
142
+ elif cu_seqlens is not None:
143
+ o = flash_attn_varlen_func(
144
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
145
+ cu_seqlens_q=cu_seqlens,
146
+ cu_seqlens_k=cu_seqlens,
147
+ max_seqlen_q=max_seqlen,
148
+ max_seqlen_k=max_seqlen,
149
+ causal=True,
150
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
151
+ ).unsqueeze(0)
152
+ else:
153
+ o = flash_attn_func(
154
+ q, k, v,
155
+ causal=True,
156
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
157
+ )
158
+ o = o.reshape(batch_size, q_len, -1)
159
+ o = self.o_proj(o)
160
+
161
+ if not output_attentions:
162
+ attentions = None
163
+
164
+ return o, attentions, past_key_values
165
+
166
+ def _upad_input(self, q, k, v, attention_mask, q_len):
167
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
168
+ cache_mask = attention_mask[:, -seq_len:]
169
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
170
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
171
+ max_seqlen_k = seqlens.max().item()
172
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
173
+
174
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
175
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
176
+ if q_len == seq_len:
177
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
178
+ cu_seqlens_q = cu_seqlens_k
179
+ max_seqlen_q = max_seqlen_k
180
+ indices_q = indices_k
181
+ elif q_len == 1:
182
+ max_seqlen_q = 1
183
+ # There is a memcpy here, that is very bad.
184
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
185
+ indices_q = cu_seqlens_q[:-1]
186
+ q = q.squeeze(1)
187
+ else:
188
+ # The -q_len: slice assumes left padding.
189
+ attention_mask = attention_mask[:, -q_len:]
190
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
191
+
192
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
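Both Attention and BitAttention use the same grouped-KV bookkeeping: when num_kv_heads is smaller than num_heads, k_proj and v_proj emit only num_kv_heads heads and each KV head serves num_kv_groups query heads. A small arithmetic sketch with the hidden size and head count from config.json (num_kv_heads=8 is illustrative, since the config leaves it null):

hidden_size, num_heads, num_kv_heads = 2048, 32, 8
head_dim = hidden_size // num_heads            # 64
kv_dim = num_kv_heads * head_dim               # 512, output width of k_proj / v_proj
num_kv_groups = num_heads // num_kv_heads      # 4 query heads share one KV head
print(head_dim, kv_dim, num_kv_groups)         # 64 512 4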
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
12
+ from fla.ops.delta_rule import chunk_delta_rule
13
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
14
+
15
+ if TYPE_CHECKING:
16
+ from transformers.processing_utils import Unpack
17
+
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ def elu_p1(x):
22
+ return (F.elu(x, 1.0, False) + 1.0).to(x)
23
+
24
+
25
+ def sum_norm(x):
26
+ return (x / x.sum(-1, keepdim=True)).to(x)
27
+
28
+
29
+ def interleave_multiple_sequences(*sequences):
30
+ """
31
+ Interleave multiple sequences together.
32
+ For example, with sequences [A1, A2], [B1, B2], [C1, C2],
33
+ returns [A1, B1, C1, A2, B2, C2]
34
+ """
35
+ if isinstance(sequences[0], (list, tuple)):
36
+ sequences = sequences[0]
37
+
38
+ if len(sequences) == 1:
39
+ return sequences[0]
40
+
41
+ # All sequences should have the same shape
42
+ assert all(s.shape == sequences[0].shape for s in sequences)
43
+
44
+ # Get the original shape
45
+ batch_size, seq_len, *rest = sequences[0].shape
46
+
47
+ # Stack sequences along a new dimension
48
+ stacked = torch.stack(sequences, dim=2)
49
+
50
+ # Reshape to interleave
51
+ reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)
52
+
53
+ return reshaped
54
+
55
+
56
+ class GatedDeltaProduct(nn.Module):
57
+ """
58
+ Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size: int = 2048,
64
+ expand_v: float = 2,
65
+ head_dim: int = 256,
66
+ num_heads: int = 6,
67
+ num_householder: int = 2, # New parameter for number of householder transformations
68
+ mode: str = "chunk",
69
+ use_gate: bool = True,
70
+ use_forget_gate: bool = True, # when true Gated DeltaProduct, when false DeltaProduct
71
+ use_short_conv: bool = True,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ layer_idx: int | None = None,
75
+ norm_eps: float = 1e-5,
76
+ allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
77
+ **kwargs,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ self.mode = mode
82
+ self.hidden_size = hidden_size
83
+ self.expand_v = expand_v
84
+ self.use_gate = use_gate
85
+ self.use_short_conv = use_short_conv
86
+ self.conv_size = conv_size
87
+ self.conv_bias = conv_bias
88
+ self.head_dim = head_dim
89
+ self.num_heads = num_heads
90
+ self.num_householder = num_householder
91
+ self.allow_neg_eigval = allow_neg_eigval
92
+ self.use_forget_gate = use_forget_gate
93
+ self.key_dim = self.num_heads * self.head_dim
94
+ self.value_dim = int(self.key_dim * self.expand_v)
95
+ self.head_qk_dim = head_dim
96
+ self.head_v_dim = int(head_dim * self.expand_v)
97
+ self.layer_idx = layer_idx
98
+ self.silu = nn.SiLU()
99
+ assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
100
+ # Create multiple projection layers for each householder transformation
101
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
102
+
103
+ self.k_projs = nn.ModuleList(
104
+ [
105
+ nn.Linear(hidden_size, self.key_dim, bias=False)
106
+ for _ in range(num_householder)
107
+ ]
108
+ )
109
+ self.v_projs = nn.ModuleList(
110
+ [
111
+ nn.Linear(hidden_size, self.value_dim, bias=False)
112
+ for _ in range(num_householder)
113
+ ]
114
+ )
115
+ self.b_projs = nn.ModuleList(
116
+ [
117
+ nn.Linear(hidden_size, self.num_heads, bias=False)
118
+ for _ in range(num_householder)
119
+ ]
120
+ )
121
+ if use_short_conv:
122
+ self.q_conv1ds = nn.ModuleList(
123
+ [
124
+ ShortConvolution(
125
+ hidden_size=self.key_dim,
126
+ kernel_size=conv_size,
127
+ activation="silu",
128
+ )
129
+ for _ in range(num_householder)
130
+ ]
131
+ )
132
+ self.k_conv1ds = nn.ModuleList(
133
+ [
134
+ ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation="silu",
138
+ )
139
+ for _ in range(num_householder)
140
+ ]
141
+ )
142
+ self.v_conv1ds = nn.ModuleList(
143
+ [
144
+ ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation="silu",
148
+ )
149
+ for _ in range(num_householder)
150
+ ]
151
+ )
152
+
153
+ if self.use_forget_gate:
154
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
155
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
156
+ A_log = torch.log(A)
157
+ self.A_log = nn.Parameter(A_log)
158
+ self.A_log._no_weight_decay = True
159
+
160
+ # Initialize dt parameters
161
+ dt_min = 0.001
162
+ dt_max = 0.1
163
+ dt_init_floor = 1e-4
164
+ dt = torch.exp(
165
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
166
+ + math.log(dt_min)
167
+ )
168
+ dt = torch.clamp(dt, min=dt_init_floor)
169
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
170
+ self.dt_bias = nn.Parameter(inv_dt)
171
+ self.dt_bias._no_weight_decay = True
172
+
173
+ if use_gate:
174
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
175
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
176
+ else:
177
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
178
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
179
+ self.k_id = torch.nn.Identity()
180
+ self.apply(self._initialize_weights)
181
+
182
+ def _initialize_weights(self, module: nn.Module):
183
+ if getattr(module, "_is_hf_initialized", False):
184
+ return
185
+ if isinstance(module, nn.Linear):
186
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
187
+ if module.bias is not None:
188
+ nn.init.zeros_(module.bias)
189
+ module._is_hf_initialized = True
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.Tensor,
194
+ attention_mask: Optional[torch.Tensor] = None,
195
+ past_key_values: Optional[Cache] = None,
196
+ use_cache: Optional[bool] = False,
197
+ output_attentions: Optional[bool] = False,
198
+ **kwargs: Unpack[Dict],
199
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
200
+ if attention_mask is not None:
201
+ assert len(attention_mask.shape) == 2, (
202
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
203
+ "for padding purposes (0 indicating padding)."
204
+ )
205
+
206
+ mode = (
207
+ "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
208
+ )
209
+ if self.training:
210
+ assert mode == "chunk", "Only chunk mode is supported in training."
211
+
212
+ last_state = None
213
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
214
+ last_state = past_key_values[self.layer_idx]
215
+
216
+ # Process each householder transformation
217
+ ks, vs, betas = [], [], []
218
+ conv_states = []
219
+
220
+ for i in range(self.num_householder):
221
+ if self.use_short_conv:
222
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
223
+ if last_state is not None:
224
+ conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][
225
+ i
226
+ ]
227
+ conv_mask = (
228
+ attention_mask[:, -hidden_states.shape[1]:]
229
+ if attention_mask is not None
230
+ else None
231
+ )
232
+
233
+ k, conv_state_k = self.k_conv1ds[i](
234
+ x=self.k_projs[i](hidden_states),
235
+ mask=conv_mask,
236
+ cache=conv_state_k,
237
+ output_final_state=use_cache,
238
+ )
239
+ v, conv_state_v = self.v_conv1ds[i](
240
+ x=self.v_projs[i](hidden_states),
241
+ mask=conv_mask,
242
+ cache=conv_state_v,
243
+ output_final_state=use_cache,
244
+ )
245
+ conv_states.append((conv_state_q, conv_state_k, conv_state_v))
246
+ else:
247
+ k = self.silu(self.k_projs[i](hidden_states))
248
+ v = self.silu(self.v_projs[i](hidden_states))
249
+
250
+ ks.append(k)
251
+ vs.append(v)
252
+
253
+ beta = self.b_projs[i](
254
+ hidden_states
255
+ ).sigmoid() # bs, sequence_length, num_heads
256
+ if attention_mask is not None:
257
+ beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
258
+ if self.allow_neg_eigval:
259
+ beta = beta * 2
260
+ betas.append(beta)
261
+
262
+ if self.use_short_conv:
263
+ q, conv_state_q = self.q_conv1ds[0](
264
+ x=self.q_proj(hidden_states),
265
+ mask=conv_mask,
266
+ cache=conv_state_q,
267
+ output_final_state=use_cache,
268
+ )
269
+ else:
270
+ q = self.silu(self.q_proj(hidden_states))
271
+ q = interleave_multiple_sequences(
272
+ [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
273
+ )
274
+ # Interleave all sequences
275
+ k = interleave_multiple_sequences(ks)
276
+ v = interleave_multiple_sequences(vs)
277
+ beta = interleave_multiple_sequences(betas)
278
+
279
+ q, k, v = (
280
+ rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
281
+ )
282
+
283
+ recurrent_state = (
284
+ last_state["recurrent_state"] if last_state is not None else None
285
+ )
286
+ offsets = kwargs.get("offsets")
287
+
288
+ if mode == "chunk":
289
+ if self.use_forget_gate:
290
+ g = -self.A_log.float().exp() * F.softplus(
291
+ self.a_proj(hidden_states).float() + self.dt_bias
292
+ )
293
+ if attention_mask is not None:
294
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
295
+
296
+ # Interleave g with zeros for non-first transformations
297
+ g = interleave_multiple_sequences(
298
+ [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
299
+ )
300
+
301
+ o, recurrent_state = chunk_gated_delta_rule(
302
+ q=q,
303
+ k=k,
304
+ v=v,
305
+ g=g,
306
+ beta=beta,
307
+ initial_state=recurrent_state,
308
+ output_final_state=use_cache,
309
+ cu_seqlens=offsets,
310
+ head_first=False,
311
+ use_qk_l2norm_in_kernel=True
312
+ )
313
+ else:
314
+ o, recurrent_state = chunk_delta_rule(
315
+ q=q,
316
+ k=k,
317
+ v=v,
318
+ beta=beta,
319
+ initial_state=recurrent_state,
320
+ output_final_state=use_cache,
321
+ cu_seqlens=offsets,
322
+ head_first=False,
323
+ use_qk_l2norm_in_kernel=True
324
+ )
325
+ else:
326
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
327
+
328
+ # Take every nth element for n householder transformations
329
+ o = o[:, self.num_householder - 1:: self.num_householder, :]
330
+
331
+ if past_key_values is not None:
332
+ past_key_values.update(
333
+ recurrent_state=recurrent_state,
334
+ conv_state=conv_states if self.use_short_conv else None,
335
+ layer_idx=self.layer_idx,
336
+ offset=q.shape[2],
337
+ )
338
+
339
+ if self.use_gate:
340
+ g = rearrange(
341
+ self.g_proj(hidden_states),
342
+ "... (h d) -> ... h d",
343
+ h=self.num_heads,
344
+ )
345
+ o = self.o_norm(o, g)
346
+ else:
347
+ o = self.o_norm(o)
348
+ o = rearrange(o, "b t h d -> b t (h d)")
349
+ o = self.o_proj(o)
350
+
351
+ return o, None, past_key_values
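The core trick in GatedDeltaProduct is the token-level interleaving above: each of the num_householder steps has its own k/v/beta projections, the streams are woven together so the delta-rule kernel sees num_householder sub-steps per token, queries are placed only on the last sub-step and forget gates only on the first, and the output keeps every num_householder-th position. A self-contained demo of interleave_multiple_sequences (same logic as the helper above) with concrete values:

import torch

def interleave_multiple_sequences(*sequences):
    # Stack the streams along a new dim, then fold that dim into the time axis.
    if isinstance(sequences[0], (list, tuple)):
        sequences = sequences[0]
    if len(sequences) == 1:
        return sequences[0]
    batch_size, seq_len, *rest = sequences[0].shape
    stacked = torch.stack(sequences, dim=2)
    return stacked.view(batch_size, seq_len * len(sequences), *rest)

a = torch.tensor([[[1.0], [2.0]]])     # (batch=1, seq=2, dim=1), stream A
b = torch.tensor([[[10.0], [20.0]]])   # stream B
print(interleave_multiple_sequences(a, b).squeeze(-1))  # tensor([[ 1., 10.,  2., 20.]])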
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class HGRNAttention(nn.Module):
25
+
26
+ def __init__(
27
+ self,
28
+ mode: str = 'chunk',
29
+ hidden_size: int = 1024,
30
+ expand_ratio: Optional[int] = 1,
31
+ use_short_conv: bool = False,
32
+ conv_size: int = 4,
33
+ conv_bias: bool = False,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: int = None
37
+ ) -> HGRNAttention:
38
+ super().__init__()
39
+
40
+ self.mode = mode
41
+ self.hidden_size = hidden_size
42
+ self.expand_ratio = expand_ratio
43
+ self.input_dim = int(hidden_size * expand_ratio)
44
+
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.conv_bias = conv_bias
48
+
49
+ self.layer_idx = layer_idx
50
+
51
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
52
+
53
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
54
+ self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
55
+ self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
56
+
57
+ if use_short_conv:
58
+ self.conv_size = conv_size
59
+ self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
60
+ self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
61
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
62
+
63
+ self.g_norm = FusedRMSNormGated(
64
+ hidden_size=self.input_dim,
65
+ elementwise_affine=elementwise_affine,
66
+ eps=norm_eps
67
+ )
68
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ past_key_values: Optional[Cache] = None,
75
+ use_cache: Optional[bool] = False,
76
+ output_attentions: Optional[bool] = False,
77
+ lower_bound: Optional[torch.Tensor] = None,
78
+ **kwargs: Unpack[Dict]
79
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
80
+ if attention_mask is not None:
81
+ assert len(attention_mask.shape) == 2, (
82
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
83
+ "for padding purposes (0 indicating padding). "
84
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
85
+ )
86
+
87
+ # launching the triton kernel for just one token will actually be slower
88
+ mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode
89
+
90
+ last_state = None
91
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
92
+ last_state = past_key_values[self.layer_idx]
93
+
94
+ cu_seqlens = kwargs.get('cu_seqlens', None)
95
+ if self.use_short_conv:
96
+ conv_state_i, conv_state_f = None, None
97
+ if last_state is not None:
98
+ conv_state_i, conv_state_f = last_state['conv_state']
99
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
100
+ i, conv_state_i = self.i_conv1d(
101
+ x=self.i_proj(hidden_states),
102
+ mask=conv_mask,
103
+ cache=conv_state_i,
104
+ output_final_state=use_cache,
105
+ cu_seqlens=cu_seqlens
106
+ )
107
+ f, conv_state_f = self.f_conv1d(
108
+ x=self.f_proj(hidden_states),
109
+ mask=conv_mask,
110
+ cache=conv_state_f,
111
+ output_final_state=use_cache,
112
+ cu_seqlens=cu_seqlens
113
+ )
114
+ else:
115
+ i = self.i_proj(hidden_states)
116
+ f = self.f_proj(hidden_states)
117
+
118
+ # the lower bound for the first layer is zero
119
+ if lower_bound is None or self.layer_idx == 0:
120
+ i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
121
+ else:
122
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
123
+ i, f = swiglu(i, 1 - g), g.log()
124
+
125
+ # dealing with left-padding
126
+ if attention_mask is not None:
127
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
128
+
129
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
130
+ if mode == 'chunk':
131
+ if cu_seqlens is not None:
132
+ raise NotImplementedError("Chunk mode does not support variable-length sequences.")
133
+ o, recurrent_state = chunk_hgrn(
134
+ x=i,
135
+ g=f,
136
+ initial_state=recurrent_state,
137
+ output_final_state=use_cache,
138
+ )
139
+ elif mode == 'fused_recurrent':
140
+ o, recurrent_state = fused_recurrent_hgrn(
141
+ x=i,
142
+ g=f,
143
+ initial_state=recurrent_state,
144
+ output_final_state=use_cache,
145
+ cu_seqlens=cu_seqlens
146
+ )
147
+ else:
148
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
149
+
150
+ if past_key_values is not None:
151
+ past_key_values.update(
152
+ recurrent_state=recurrent_state,
153
+ conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
154
+ layer_idx=self.layer_idx,
155
+ offset=i.shape[2]
156
+ )
157
+
158
+ o = self.g_norm(o, self.g_proj(hidden_states))
159
+ o = self.o_proj(o)
160
+
161
+ return o, None, past_key_values
162
+
163
+ def state_size(self, **kwargs) -> int:
164
+ state_size = self.hidden_size
165
+ for module in self.children():
166
+ if isinstance(module, ShortConvolution):
167
+ state_size += module.state_size
168
+ return state_size
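The lower_bound argument in HGRNAttention.forward implements the hierarchical gating of the HGRN paper: layers beyond the first receive a per-channel lower bound, so their forget gate is confined to [lower_bound, 1) rather than (0, 1) and deeper layers keep slower-decaying state. A small numeric sketch of that transform (values are illustrative):

import torch

f = torch.tensor([-2.0, 0.0, 2.0])            # raw forget-gate pre-activations
lower_bound = torch.tensor([0.5, 0.5, 0.5])
g = lower_bound + (1 - lower_bound) * f.sigmoid()
print(g)                                       # every entry lies in [0.5, 1.0)
print(g.log())                                 # log-space gate handed to chunk_hgrn / fused_recurrent_hgrn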
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.processing_utils import Unpack
22
+
23
+ from fla.models.utils import Cache
24
+
25
+
26
+ class HGRN2Attention(nn.Module):
27
+
28
+ def __init__(
29
+ self,
30
+ mode: str = 'chunk',
31
+ hidden_size: int = 1024,
32
+ num_heads: Optional[int] = None,
33
+ expand_ratio: Optional[int] = 128,
34
+ use_short_conv: bool = False,
35
+ conv_size: int = 4,
36
+ conv_bias: bool = False,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> HGRN2Attention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.forget_dim = int(self.num_heads * self.expand_ratio)
60
+ self.input_dim = hidden_size
61
+ self.layer_idx = layer_idx
62
+
63
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
64
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
65
+ assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
66
+
67
+ self.head_f_dim = self.expand_ratio
68
+ self.head_i_dim = self.hidden_size // num_heads
69
+
70
+ self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
71
+ self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
72
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
73
+
74
+ if use_short_conv:
75
+ self.conv_size = conv_size
76
+ self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
77
+ self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
78
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
79
+
80
+ self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
81
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
82
+
83
+ def forward(
84
+ self,
85
+ hidden_states: torch.Tensor,
86
+ attention_mask: Optional[torch.Tensor] = None,
87
+ past_key_values: Optional[Cache] = None,
88
+ use_cache: Optional[bool] = False,
89
+ output_attentions: Optional[bool] = False,
90
+ lower_bound: Optional[torch.Tensor] = None,
91
+ **kwargs: Unpack[Dict]
92
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
93
+ if attention_mask is not None:
94
+ assert len(attention_mask.shape) == 2, (
95
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
96
+ "for padding purposes (0 indicating padding). "
97
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
98
+ )
99
+
100
+ # launching the triton kernel for just one token will actually be slower
101
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
102
+
103
+ last_state = None
104
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
105
+ last_state = past_key_values[self.layer_idx]
106
+
107
+ cu_seqlens = kwargs.get('cu_seqlens', None)
108
+ if self.use_short_conv:
109
+ conv_state_q, conv_state_f, conv_state_i = None, None, None
110
+ if last_state is not None:
111
+ conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
112
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
113
+ q, conv_state_q = self.q_conv1d(
114
+ x=self.q_proj(hidden_states),
115
+ mask=conv_mask,
116
+ cache=conv_state_q,
117
+ output_final_state=use_cache,
118
+ cu_seqlens=cu_seqlens
119
+ )
120
+ f, conv_state_f = self.f_conv1d(
121
+ x=self.f_proj(hidden_states),
122
+ mask=conv_mask,
123
+ cache=conv_state_f,
124
+ output_final_state=use_cache,
125
+ cu_seqlens=cu_seqlens
126
+ )
127
+ i, conv_state_i = self.i_conv1d(
128
+ x=self.i_proj(hidden_states),
129
+ mask=conv_mask,
130
+ cache=conv_state_i,
131
+ output_final_state=use_cache,
132
+ cu_seqlens=cu_seqlens
133
+ )
134
+ else:
135
+ q = self.q_proj(hidden_states)
136
+ f = self.f_proj(hidden_states)
137
+ i = self.i_proj(hidden_states)
138
+
139
+ # dealing with left-padding
140
+ if attention_mask is not None:
141
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
142
+
143
+ q = swish(q)
144
+
145
+ # improve precision
146
+ f = f.float()
147
+
148
+ # the lower bound for the first layer is zero
149
+ if lower_bound is None or self.layer_idx == 0:
150
+ k, g = 1 - f.sigmoid(), F.logsigmoid(f)
151
+ else:
152
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
153
+ k, g = 1 - g, g.log()
154
+
155
+ q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
156
+ i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)
157
+
158
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
159
+ if mode == 'fused_recurrent':
160
+ o, recurrent_state = fused_recurrent_gla(
161
+ q=q,
162
+ k=k,
163
+ v=i,
164
+ gk=g,
165
+ initial_state=recurrent_state,
166
+ output_final_state=use_cache,
167
+ cu_seqlens=cu_seqlens,
168
+ head_first=False
169
+ )
170
+ elif mode == 'fused_chunk':
171
+ o, recurrent_state = fused_chunk_gla(
172
+ q=q,
173
+ k=k,
174
+ v=i,
175
+ g=g,
176
+ initial_state=recurrent_state,
177
+ output_final_state=use_cache,
178
+ head_first=False
179
+ )
180
+ elif mode == 'chunk':
181
+ o, recurrent_state = chunk_gla(
182
+ q=q,
183
+ k=k,
184
+ v=i,
185
+ g=g,
186
+ initial_state=recurrent_state,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens,
189
+ head_first=False
190
+ )
191
+ else:
192
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
193
+
194
+ if past_key_values is not None:
195
+ past_key_values.update(
196
+ recurrent_state=recurrent_state,
197
+ conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
198
+ layer_idx=self.layer_idx,
199
+ offset=q.shape[1]
200
+ )
201
+
202
+ o = rearrange(o, '... h d -> ... (h d)')
203
+ o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
204
+ return o, None, past_key_values
205
+
206
+ def state_size(self, **kwargs) -> int:
207
+ state_size = self.forget_dim * self.head_i_dim
208
+ for module in self.children():
209
+ if isinstance(module, ShortConvolution):
210
+ state_size += module.state_size
211
+ return state_size
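HGRN2 ties the key to the forget gate: with g = logsigmoid(f) and k = 1 - sigmoid(f), the key is exactly 1 - exp(g), which is what lets the layer reuse the GLA kernels with the gate passed as gk. A one-line check of that identity:

import torch
import torch.nn.functional as F

f = torch.randn(5)
k, g = 1 - f.sigmoid(), F.logsigmoid(f)
print(torch.allclose(k, 1 - g.exp()))   # True: key and forget gate are complementary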
fla/layers/lightnet.py ADDED
@@ -0,0 +1,210 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022)
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import FusedRMSNormGated, ShortConvolution
16
+ from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear
17
+ from fla.ops.gla import chunk_gla, fused_recurrent_gla
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class LightNetAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ num_heads: Optional[int] = None,
32
+ expand_ratio: Optional[int] = 128,
33
+ use_short_conv: bool = False,
34
+ conv_size: int = 4,
35
+ conv_bias: bool = False,
36
+ gate_low_rank_dim: int = 128,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> LightNetAttention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.key_dim = int(self.num_heads * self.expand_ratio)
60
+ self.value_dim = hidden_size
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.layer_idx = layer_idx
63
+
64
+ assert mode in ['chunk', 'fused_chunk'], f"Not supported mode `{mode}`."
65
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
66
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
67
+
68
+ self.head_f_dim = self.expand_ratio
69
+ self.head_i_dim = self.hidden_size // num_heads
70
+
71
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
72
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
73
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
74
+
75
+ if use_short_conv:
76
+ self.conv_size = conv_size
77
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
78
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
79
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None)
80
+
81
+ self.g_proj = nn.Sequential(
82
+ nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
83
+ nn.Linear(gate_low_rank_dim, hidden_size, bias=False)
84
+ )
85
+ self.g_norm = FusedRMSNormGated(
86
+ hidden_size=hidden_size,
87
+ elementwise_affine=elementwise_affine,
88
+ eps=norm_eps
89
+ )
90
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ past_key_values: Optional[Cache] = None,
97
+ use_cache: Optional[bool] = False,
98
+ output_attentions: Optional[bool] = False,
99
+ **kwargs: Unpack[Dict]
100
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
101
+ if attention_mask is not None:
102
+ assert len(attention_mask.shape) == 2, (
103
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
104
+ "for padding purposes (0 indicating padding). "
105
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
106
+ )
107
+
108
+ # launching the triton kernel for just one token will actually be slower
109
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
110
+
111
+ last_state = None
112
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
113
+ last_state = past_key_values[self.layer_idx]
114
+
115
+ cu_seqlens = kwargs.get('cu_seqlens', None)
116
+ if self.use_short_conv:
117
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
118
+ if last_state is not None:
119
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
120
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
121
+ q, conv_state_q = self.q_conv1d(
122
+ x=self.q_proj(hidden_states),
123
+ mask=conv_mask,
124
+ cache=conv_state_q,
125
+ output_final_state=use_cache,
126
+ cu_seqlens=cu_seqlens
127
+ )
128
+ k, conv_state_k = self.k_conv1d(
129
+ x=self.k_proj(hidden_states),
130
+ mask=conv_mask,
131
+ cache=conv_state_k,
132
+ output_final_state=use_cache,
133
+ cu_seqlens=cu_seqlens
134
+ )
135
+ v, conv_state_v = self.v_conv1d(
136
+ x=self.v_proj(hidden_states),
137
+ mask=conv_mask,
138
+ cache=conv_state_v,
139
+ output_final_state=use_cache,
140
+ cu_seqlens=cu_seqlens
141
+ )
142
+ else:
143
+ q = self.q_proj(hidden_states)
144
+ k = self.k_proj(hidden_states)
145
+ v = self.v_proj(hidden_states)
146
+
147
+ # dealing with left-padding
148
+ if attention_mask is not None:
149
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
150
+
151
+ q = F.silu(q)
152
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k))
153
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_i_dim)
154
+ # TODO: this 2 steps took huge amount of time, which should be optimized
155
+ z = k.float().logcumsumexp(1)
156
+
157
+ if cu_seqlens is not None:
158
+ raise NotImplementedError("LightNet does not support variable-length sequences for now.")
159
+ k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype)
160
+
161
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
162
+ if mode == 'fused_recurrent':
163
+ o, recurrent_state = fused_recurrent_gla(
164
+ q=q,
165
+ k=k,
166
+ v=v,
167
+ gk=g,
168
+ initial_state=recurrent_state,
169
+ output_final_state=use_cache,
170
+ cu_seqlens=cu_seqlens,
171
+ head_first=False
172
+ )
173
+ elif mode == 'chunk':
174
+ o, recurrent_state = chunk_gla(
175
+ q=q,
176
+ k=k,
177
+ v=v,
178
+ g=g,
179
+ initial_state=recurrent_state,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens,
182
+ head_first=False
183
+ )
184
+ else:
185
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
186
+
187
+ if past_key_values is not None:
188
+ past_key_values.update(
189
+ recurrent_state=recurrent_state,
190
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
191
+ layer_idx=self.layer_idx,
192
+ offset=q.shape[1]
193
+ )
194
+
195
+ o = rms_norm_swish_gate_linear(
196
+ rearrange(o, 'b t h d -> b t (h d)'),
197
+ self.g_proj(hidden_states),
198
+ self.g_norm.weight,
199
+ self.g_norm.bias,
200
+ self.o_proj.weight,
201
+ self.o_proj.bias
202
+ )
203
+ return o, None, past_key_values
204
+
205
+ def state_size(self, **kwargs) -> int:
206
+ state_size = self.key_dim * self.head_i_dim
207
+ for module in self.children():
208
+ if isinstance(module, ShortConvolution):
209
+ state_size += module.state_size
210
+ return state_size
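For readers tracing the forward pass above, the sketch below restates the cumulative-softmax gating it computes in plain PyTorch; the tensor shapes are illustrative assumptions and this is a reference only, not the fused kernel path.

import torch

k = torch.randn(2, 8, 4, 16).float()            # assumed [batch, time, heads, head_f_dim]
z = k.logcumsumexp(1)                           # running log-normalizer over time
k_norm = torch.exp(k - z)                       # normalized keys in (0, 1]
g = torch.cat((z[:, :1], z[:, :-1]), 1) - z     # non-positive log decay passed as `gk`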
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from fla.modules import RMSNorm
12
+ from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
13
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
14
+
15
+
16
+ class LinearAttention(nn.Module):
17
+
18
+ def __init__(
19
+ self,
20
+ mode: str = 'chunk',
21
+ hidden_size: str = 1024,
22
+ expand_k: int = 1.0,
23
+ expand_v: int = 1.0,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: str = 'elementwise_product',
27
+ tie_feature_map_qk: bool = False,
28
+ output_norm: str = 'rmsnorm',
29
+ norm_q: bool = False,
30
+ norm_k: bool = False,
31
+ do_feature_map_norm: bool = False,
32
+ elementwise_affine: bool = True,
33
+ norm_eps: float = 1e-5,
34
+ **kwargs
35
+ ):
36
+ super().__init__()
37
+
38
+ self.hidden_size = hidden_size
39
+ self.mode = mode
40
+ self.num_heads = num_heads
41
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
42
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
43
+ self.key_dim = int(hidden_size * expand_k)
44
+ self.value_dim = int(hidden_size * expand_v)
45
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
46
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
47
+
48
+ assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
49
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
50
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
51
+
52
+ self.head_k_dim = self.key_dim // num_heads
53
+ self.head_v_dim = self.value_dim // num_heads
54
+ self.do_feature_map_norm = do_feature_map_norm
55
+
56
+ if feature_map == 'hedgehog':
57
+ if tie_feature_map_qk:
58
+ self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
59
+ else:
60
+ self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
61
+ self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
62
+
63
+ elif feature_map == 't2r':
64
+ if tie_feature_map_qk:
65
+ self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
66
+ else:
67
+ self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
68
+ self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
69
+
70
+ elif feature_map == 'elementwise_product':
71
+ if tie_feature_map_qk:
72
+ self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
73
+ else:
74
+ self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
75
+ self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
76
+
77
+ elif feature_map == 'dpfp':
78
+ self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
79
+ self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)
80
+
81
+ elif feature_map == 'elu':
82
+ def elu(x):
83
+ return F.elu(x) + 1
84
+ self.feature_map_q = elu
85
+ self.feature_map_k = elu
86
+
87
+ elif feature_map == 'relu':
88
+ self.feature_map_q = nn.ReLU()
89
+ self.feature_map_k = nn.ReLU()
90
+
91
+ elif feature_map == 'identity':
92
+ self.feature_map_q = nn.Identity()
93
+ self.feature_map_k = nn.Identity()
94
+ else:
95
+ raise NotImplementedError(f"Not supported feature map `{feature_map}`.")
96
+
97
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
98
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
99
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
100
+
101
+ if output_norm == 'rmsnorm':
102
+ self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
103
+ elif output_norm == 'identity':
104
+ self.norm = nn.Identity()
105
+ else:
106
+ raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
107
+
108
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
109
+
110
+ self.norm_q = norm_q
111
+ self.norm_k = norm_k
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: torch.Tensor,
116
+ **kwargs
117
+ ) -> torch.Tensor:
118
+ mode = self.mode
119
+ q = self.q_proj(hidden_states)
120
+ k = self.k_proj(hidden_states)
121
+ v = self.v_proj(hidden_states)
122
+
123
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
124
+ if self.num_kv_groups > 1:
125
+ k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
126
+ v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
127
+ else:
128
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
129
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
130
+
131
+ q = self.feature_map_q(q)
132
+ k = self.feature_map_k(k)
133
+
134
+ if self.norm_q:
135
+ q = q / (q.sum(-1, True) + 1e-4)
136
+ if self.norm_k:
137
+ k = k / (k.sum(-1, True) + 1e-4)
138
+
139
+ if mode == 'chunk':
140
+ o, final_state = chunk_linear_attn(
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ normalize=self.do_feature_map_norm,
145
+ head_first=False
146
+ )
147
+ elif mode == 'fused_chunk':
148
+ o, final_state = fused_chunk_linear_attn(
149
+ q=q,
150
+ k=k,
151
+ v=v,
152
+ normalize=self.do_feature_map_norm,
153
+ )
154
+ elif mode == 'fused_recurrent':
155
+ o, final_state = fused_recurrent_linear_attn(
156
+ q=q,
157
+ k=k,
158
+ v=v,
159
+ normalize=self.do_feature_map_norm,
160
+ )
161
+ else:
162
+ raise NotImplementedError
163
+ o = self.norm(o)
164
+ o = rearrange(o, '... h d -> ... (h d)')
165
+ o = self.o_proj(o)
166
+ return o
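As a hedged reference for the `chunk`/`fused_chunk`/`fused_recurrent` kernels invoked above, an unnormalized causal linear-attention recurrence can be written naively as follows; `naive_linear_attn` is a hypothetical helper for illustration and assumes the feature maps have already been applied to `q` and `k`.

import torch

def naive_linear_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k: [B, T, H, Dk]; v: [B, T, H, Dv]
    B, T, H, Dk = q.shape
    Dv = v.shape[-1]
    S = q.new_zeros(B, H, Dk, Dv)                        # running state sum_i k_i v_i^T
    outputs = []
    for t in range(T):
        S = S + torch.einsum('bhk,bhv->bhkv', k[:, t], v[:, t])
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, t], S))
    return torch.stack(outputs, dim=1)                   # [B, T, H, Dv]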
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,298 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class MultiScaleRetention(nn.Module):
22
+ r"""
23
+ The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
24
+
25
+ Args:
26
+ mode (str, Optional):
27
+ Which Retention kernel to use.
28
+ Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
29
+ Default: `chunk`.
30
+ hidden_size (int, Optional):
31
+ The hidden size of the input. Default: 1024.
32
+ expand_k (float, Optional):
33
+ The expansion ratio for the key dim. Default: 1.0.
34
+ expand_v (float, Optional):
35
+ The expansion ratio for the value dim. Default: 2.0.
36
+ num_heads (int, Optional):
37
+ The number of heads. Default: 8.
38
+ num_kv_heads (int, Optional):
39
+ The number of key/value heads, used for MQA. Default: None.
40
+ feature_map (str, Optional):
41
+ Feature map function applied to queries/keys. Default: None.
42
+ use_short_conv (bool, Optional):
43
+ Whether to use short convolutions. Default: `False`.
44
+ conv_size (int, Optional):
45
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
46
+ conv_bias (bool, Optional):
47
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
48
+ use_output_gate (bool, Optional):
49
+ Whether to use output gate. Default: `True`.
50
+ gate_fn (str, Optional):
51
+ The activation function for the output gate. Default: `swish`.
52
+ elementwise_affine (bool, Optional):
53
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
54
+ norm_eps (float, Optional):
55
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
56
+ fuse_norm (bool, Optional):
57
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
58
+ layer_idx (int, Optional):
59
+ The index of the layer. Default: None.
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ mode: str = 'chunk',
65
+ hidden_size: int = 1024,
66
+ expand_k: float = 1.0,
67
+ expand_v: float = 2.0,
68
+ num_heads: int = 8,
69
+ num_kv_heads: Optional[int] = None,
70
+ feature_map: Optional[str] = None,
71
+ use_short_conv: bool = False,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ use_output_gate: bool = True,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ fuse_norm: bool = True,
79
+ layer_idx: int = None,
80
+ **kwargs
81
+ ) -> MultiScaleRetention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+ self.use_output_gate = use_output_gate
97
+
98
+ self.key_dim = int(hidden_size * expand_k)
99
+ self.value_dim = int(hidden_size * expand_v)
100
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
101
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
102
+ self.layer_idx = layer_idx
103
+
104
+ assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
105
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
106
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
107
+
108
+ self.head_k_dim = self.key_dim // num_heads
109
+ self.head_v_dim = self.value_dim // num_heads
110
+
111
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
112
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
113
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
114
+ if self.use_output_gate:
115
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
116
+
117
+ if use_short_conv:
118
+ self.conv_size = conv_size
119
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
120
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
121
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
122
+
123
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
124
+
125
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
126
+ self.g_norm_swish_gate = FusedRMSNormGated(
127
+ hidden_size=self.head_v_dim,
128
+ elementwise_affine=elementwise_affine,
129
+ eps=norm_eps
130
+ )
131
+ self.fuse_norm_and_gate = True
132
+ else:
133
+ self.fuse_norm_and_gate = False
134
+ self.g_norm = RMSNorm(
135
+ hidden_size=self.head_v_dim,
136
+ elementwise_affine=elementwise_affine,
137
+ eps=norm_eps
138
+ )
139
+ self.gate_fn = ACT2FN[gate_fn]
140
+
141
+ # TODO: fix this issue
142
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
143
+ # Ideally, we would want to support arbitrary d_head_qk
144
+ assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256"
145
+ self.rotary = RotaryEmbedding(dim=self.head_k_dim)
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ past_key_values: Optional[Cache] = None,
152
+ use_cache: Optional[bool] = False,
153
+ output_attentions: Optional[bool] = False,
154
+ **kwargs
155
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
156
+ if attention_mask is not None:
157
+ assert len(attention_mask.shape) == 2, (
158
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
159
+ "for padding purposes (0 indicating padding). "
160
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
161
+ )
162
+
163
+ # launching the triton kernel for just one token will actually be slower
164
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
165
+
166
+ last_state = None
167
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
168
+ last_state = past_key_values[self.layer_idx]
169
+
170
+ cu_seqlens = kwargs.get('cu_seqlens', None)
171
+ if self.use_short_conv:
172
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
173
+ if last_state is not None:
174
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
175
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
176
+ q, conv_state_q = self.q_conv1d(
177
+ x=self.q_proj(hidden_states),
178
+ mask=conv_mask,
179
+ cache=conv_state_q,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens
182
+ )
183
+ k, conv_state_k = self.k_conv1d(
184
+ x=self.k_proj(hidden_states),
185
+ mask=conv_mask,
186
+ cache=conv_state_k,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens
189
+ )
190
+ v, conv_state_v = self.v_conv1d(
191
+ x=self.v_proj(hidden_states),
192
+ mask=conv_mask,
193
+ cache=conv_state_v,
194
+ output_final_state=use_cache,
195
+ cu_seqlens=cu_seqlens
196
+ )
197
+ else:
198
+ q = self.q_proj(hidden_states)
199
+ k = self.k_proj(hidden_states)
200
+ v = self.v_proj(hidden_states)
201
+
202
+ # dealing with left-padding
203
+ if attention_mask is not None:
204
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
205
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
206
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
207
+ if self.feature_map_fn is not None:
208
+ q, k = map(self.feature_map_fn, (q, k))
209
+
210
+ seqlen_offset, max_seqlen = 0, q.shape[1]
211
+ if past_key_values is not None:
212
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
213
+ max_seqlen = q.shape[1] + seqlen_offset
214
+
215
+ if attention_mask is not None:
216
+ # to deliminate the offsets of padding tokens
217
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
218
+ max_seqlen = q.shape[1] + max(seqlen_offset)
219
+
220
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
221
+
222
+ if self.num_kv_groups > 1:
223
+ k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
224
+ v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
225
+ else:
226
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
227
+
228
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
229
+ if mode == 'chunk':
230
+ o, recurrent_state = chunk_retention(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ initial_state=recurrent_state,
235
+ output_final_state=use_cache,
236
+ cu_seqlens=cu_seqlens,
237
+ head_first=False
238
+ )
239
+ elif mode == 'fused_chunk':
240
+ o, recurrent_state = fused_chunk_retention(
241
+ q=q,
242
+ k=k,
243
+ v=v,
244
+ initial_state=recurrent_state,
245
+ output_final_state=use_cache,
246
+ cu_seqlens=cu_seqlens,
247
+ head_first=False
248
+ )
249
+ elif mode == 'parallel':
250
+ o, recurrent_state = parallel_retention(
251
+ q=q,
252
+ k=k,
253
+ v=v,
254
+ cu_seqlens=cu_seqlens,
255
+ head_first=False
256
+ )
257
+ elif mode == 'fused_recurrent':
258
+ o, recurrent_state = fused_recurrent_retention(
259
+ q=q,
260
+ k=k,
261
+ v=v,
262
+ initial_state=recurrent_state,
263
+ output_final_state=use_cache,
264
+ cu_seqlens=cu_seqlens,
265
+ head_first=False
266
+ )
267
+ else:
268
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
269
+
270
+ if past_key_values is not None:
271
+ past_key_values.update(
272
+ recurrent_state=recurrent_state,
273
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
274
+ layer_idx=self.layer_idx,
275
+ offset=q.shape[1]
276
+ )
277
+
278
+ if self.use_output_gate:
279
+ g = self.g_proj(hidden_states)
280
+ if self.fuse_norm_and_gate:
281
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
282
+ o = self.g_norm_swish_gate(o, g)
283
+ o = rearrange(o, 'b t h d -> b t (h d)')
284
+ else:
285
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
286
+ o = o * self.gate_fn(g)
287
+ else:
288
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
292
+
293
+ def state_size(self, **kwargs) -> int:
294
+ state_size = self.key_dim * self.head_v_dim
295
+ for module in self.children():
296
+ if isinstance(module, ShortConvolution):
297
+ state_size += module.state_size
298
+ return state_size
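A quick worked example of the `state_size` computed above, using the constructor defaults (hidden_size=1024, expand_k=1.0, expand_v=2.0, num_heads=8); the numbers are illustrative only.

key_dim = int(1024 * 1.0)            # 1024
head_v_dim = int(1024 * 2.0) // 8    # 256
print(key_dim * head_v_dim)          # 262144 elements of recurrent state per sequence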
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # account for left-padding tokens when computing the offsets
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ # attention weights are not materialized by the sparse kernel
136
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
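The gating branch above packs three logits per head into `g_proj`; a minimal sketch of how the compressed/selected/sliding-window gates are recovered is shown below (shapes are assumptions for illustration).

import torch

B, T, H = 2, 16, 64                      # assumed batch, sequence length, num_heads
g = torch.randn(B, T, H * 3)             # stand-in for the g_proj output
g = g.view(B, T, H, 3).sigmoid()
g_cmp, g_slc, g_swa = g.unbind(-1)       # each [B, T, H], weighting the three branches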
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV6Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ expand_k: float = 0.5,
29
+ expand_v: float = 1.0,
30
+ num_heads: int = 4,
31
+ gate_fn: str = 'swish',
32
+ proj_low_rank_dim: int = 32,
33
+ gate_low_rank_dim: int = 64,
34
+ fuse_norm: bool = True,
35
+ elementwise_affine: Optional[bool] = True,
36
+ norm_eps: float = 1e-5,
37
+ layer_idx: int = None,
38
+ **kwargs
39
+ ) -> RWKV6Attention:
40
+ super().__init__()
41
+
42
+ self.mode = mode
43
+ self.hidden_size = hidden_size
44
+ self.expand_k = expand_k
45
+ self.expand_v = expand_v
46
+ self.num_heads = num_heads
47
+ self.proj_low_rank_dim = proj_low_rank_dim
48
+ self.gate_low_rank_dim = gate_low_rank_dim
49
+
50
+ self.key_dim = int(hidden_size * expand_k)
51
+ self.value_dim = int(hidden_size * expand_v)
52
+ self.layer_idx = layer_idx
53
+
54
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
55
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
56
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
57
+
58
+ self.head_k_dim = self.key_dim // num_heads
59
+ self.head_v_dim = self.value_dim // num_heads
60
+
61
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
62
+ self.x_proj = nn.Sequential(
63
+ LerpLinear(hidden_size, proj_low_rank_dim * 5),
64
+ nn.Tanh(),
65
+ nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
66
+ )
67
+ self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
68
+
69
+ self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
70
+ self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
71
+ self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
72
+ self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
73
+ self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
74
+ self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))
75
+
76
+ # TODO: fuse GroupNorm and output gate
77
+ self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
78
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
79
+ self.gate_fn = ACT2FN[gate_fn]
80
+
81
+ self.apply(self._initialize_weights)
82
+
83
+ def _initialize_weights(self, module: nn.Module):
84
+ if getattr(module, "_is_hf_initialized", False):
85
+ return
86
+ if isinstance(module, nn.Linear):
87
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
88
+ if module.bias is not None:
89
+ nn.init.zeros_(module.bias)
90
+ if isinstance(module, nn.Parameter):
91
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
92
+ module._is_hf_initialized = True
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ attention_mask: Optional[torch.Tensor] = None,
98
+ past_key_values: Optional[Cache] = None,
99
+ use_cache: Optional[bool] = False,
100
+ output_attentions: Optional[bool] = False,
101
+ **kwargs
102
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
103
+ if attention_mask is not None:
104
+ assert len(attention_mask.shape) == 2, (
105
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
106
+ "for padding purposes (0 indicating padding). "
107
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
108
+ )
109
+
110
+ batch_size, seq_len, hidden_size = hidden_states.shape
111
+ # launching the triton kernel for just one token will actually be slower
112
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
113
+
114
+ last_state = None
115
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
116
+ last_state = past_key_values[self.layer_idx]
117
+
118
+ if attention_mask is not None:
119
+ hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
120
+ if hidden_states.shape[1] == 1 and last_state is not None:
121
+ shifted = last_state['conv_state'].unsqueeze(1)
122
+ else:
123
+ shifted = self.time_shift(hidden_states)
124
+ if last_state is not None:
125
+ shifted[:, 0] = last_state['conv_state']
126
+
127
+ delta = shifted - hidden_states
128
+ x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
129
+ x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
130
+
131
+ r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
132
+ r = self.r_proj(hidden_states, r, delta)
133
+ w = self.w_proj(hidden_states, w, delta)
134
+ k = self.k_proj(hidden_states, k, delta)
135
+ v = self.v_proj(hidden_states, v, delta)
136
+ g = self.g_proj(hidden_states, g, delta)
137
+
138
+ # dealing with left-padding
139
+ if attention_mask is not None:
140
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
141
+ r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
142
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
143
+ w = -torch.exp(w)
144
+ u = self.bonus
145
+
146
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
147
+ cu_seqlens = kwargs.get('cu_seqlens', None)
148
+ if mode == 'fused_recurrent':
149
+ o, recurrent_state = fused_recurrent_rwkv6(
150
+ r=r,
151
+ k=k,
152
+ v=v,
153
+ w=w,
154
+ u=u,
155
+ scale=1.,
156
+ initial_state=recurrent_state,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens,
159
+ head_first=False
160
+ )
161
+ elif mode == 'chunk':
162
+ o, recurrent_state = chunk_rwkv6(
163
+ q=r,
164
+ k=k,
165
+ v=v,
166
+ g=w,
167
+ u=u,
168
+ scale=1.,
169
+ initial_state=recurrent_state,
170
+ output_final_state=use_cache,
171
+ cu_seqlens=cu_seqlens,
172
+ head_first=False
173
+ )
174
+ else:
175
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
176
+
177
+ if past_key_values is not None:
178
+ past_key_values.update(
179
+ recurrent_state=recurrent_state,
180
+ conv_state=hidden_states[:, -1],
181
+ layer_idx=self.layer_idx,
182
+ offset=r.shape[2]
183
+ )
184
+
185
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
186
+ o = self.o_proj(o)
187
+
188
+ return o, None, past_key_values
189
+
190
+
191
+ class LoRA(nn.Module):
192
+
193
+ def __init__(
194
+ self,
195
+ input_dim: int,
196
+ output_dim: int,
197
+ low_rank_dim: int,
198
+ bias: Optional[bool] = True,
199
+ activation: Optional[str] = 'tanh'
200
+ ):
201
+ super().__init__()
202
+
203
+ self.input_dim = input_dim
204
+ self.output_dim = output_dim
205
+ self.low_rank_dim = low_rank_dim
206
+ self.bias = bias
207
+
208
+ if activation is None:
209
+ self.activation = nn.Identity()
210
+ elif activation == 'sigmoid':
211
+ self.activation = nn.Sigmoid()
212
+ elif activation == 'tanh':
213
+ self.activation = nn.Tanh()
214
+ elif activation == 'relu':
215
+ self.activation = nn.ReLU()
216
+ else:
217
+ raise ValueError(f"Not supported activation `{activation}`.")
218
+
219
+ self.lora = nn.Sequential(
220
+ nn.Linear(input_dim, low_rank_dim, bias=False),
221
+ self.activation,
222
+ nn.Linear(low_rank_dim, output_dim, bias=bias)
223
+ )
224
+
225
+ def __repr__(self) -> str:
226
+ s = f"{self.__class__.__name__}("
227
+ s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
228
+ if not self.bias:
229
+ s += f", bias={self.bias}"
230
+ s += ")"
231
+ return s
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ return self.lora(x)
235
+
236
+
237
+ class LerpLinear(nn.Module):
238
+
239
+ def __init__(
240
+ self,
241
+ input_dim: int,
242
+ output_dim: int,
243
+ low_rank_dim: Optional[int] = None
244
+ ):
245
+ super().__init__()
246
+
247
+ self.input_dim = input_dim
248
+ self.output_dim = output_dim
249
+ self.low_rank_dim = low_rank_dim
250
+
251
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
252
+ if low_rank_dim is None:
253
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
254
+ else:
255
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
256
+ self.mu = nn.Parameter(torch.zeros(input_dim))
257
+
258
+ def __repr__(self) -> str:
259
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
260
+ if self.low_rank_dim is not None:
261
+ s += f", low_rank_dim={self.low_rank_dim}"
262
+ s += ")"
263
+ return s
264
+
265
+ def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
266
+ if delta is None:
267
+ shifted = self.time_shift(x)
268
+ if len(shifted.shape) == 2:
269
+ shifted = shifted.unsqueeze(1)
270
+ delta = shifted - x
271
+ return self.linear(x + delta * self.mu)
272
+
273
+
274
+ class DDLerpLinear(nn.Module):
275
+
276
+ def __init__(
277
+ self,
278
+ input_dim: int,
279
+ output_dim: int,
280
+ low_rank_dim: Optional[int] = None
281
+ ):
282
+ super().__init__()
283
+
284
+ self.input_dim = input_dim
285
+ self.output_dim = output_dim
286
+ self.low_rank_dim = low_rank_dim
287
+
288
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
289
+ if low_rank_dim is None:
290
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
291
+ else:
292
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
293
+
294
+ def __repr__(self) -> str:
295
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
296
+ if self.low_rank_dim is not None:
297
+ s += f", low_rank_dim={self.low_rank_dim}"
298
+ s += ")"
299
+ return s
300
+
301
+ def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
302
+ if delta is None:
303
+ shifted = self.time_shift(x)
304
+ if len(shifted.shape) == 2:
305
+ shifted = shifted.unsqueeze(1)
306
+ delta = shifted - x
307
+ return self.linear(x + delta * mu)
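A minimal sketch of the token shift and interpolation that `LerpLinear`/`DDLerpLinear` implement above: each position is linearly interpolated with its predecessor before the projection. The shapes and the zero `mu` are illustrative assumptions.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)                    # [batch, time, dim]
shifted = F.pad(x, (0, 0, 1, -1))            # same effect as nn.ZeroPad2d((0, 0, 1, -1))
mu = torch.zeros(16)                         # learnable in the layer; zeros here
mixed = x + (shifted - x) * mu               # equals x when mu == 0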
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,261 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class SimpleGatedLinearAttention(nn.Module):
22
+ r"""
23
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
24
+ This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+ Currently available: `chunk` and `fused_recurrent`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ gate_fn (str, Optional):
50
+ The activation function for the output gate. Default: `swish`.
51
+ elementwise_affine (bool, Optional):
52
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
53
+ norm_eps (float, Optional):
54
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
55
+ gate_logit_normalizer (int, Optional):
56
+ The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.,
68
+ expand_v: float = 1.,
69
+ num_heads: int = 4,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = True,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ gate_logit_normalizer: int = 16,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ ) -> SimpleGatedLinearAttention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+
97
+ self.key_dim = int(hidden_size * expand_k)
98
+ self.value_dim = int(hidden_size * expand_v)
99
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
100
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
101
+ self.layer_idx = layer_idx
102
+
103
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
104
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
105
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
106
+
107
+ self.head_k_dim = self.key_dim // num_heads
108
+ self.head_v_dim = self.value_dim // num_heads
109
+
110
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
111
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
112
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
113
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
114
+
115
+ if use_short_conv:
116
+ self.conv_size = conv_size
117
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
118
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
119
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
120
+
121
+ self.gk_proj = nn.Linear(hidden_size, self.num_heads)
122
+
123
+ if gate_fn == 'swish' and fuse_norm:
124
+ self.g_norm_swish_gate = FusedRMSNormGated(
125
+ hidden_size=self.head_v_dim,
126
+ elementwise_affine=elementwise_affine,
127
+ eps=norm_eps
128
+ )
129
+ self.fuse_norm_and_gate = True
130
+ else:
131
+ self.fuse_norm_and_gate = False
132
+ self.g_norm = RMSNorm(
133
+ hidden_size=self.head_v_dim,
134
+ elementwise_affine=elementwise_affine,
135
+ eps=norm_eps
136
+ )
137
+ self.gate_fn = ACT2FN[gate_fn]
138
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
139
+
140
+ self.gate_logit_normalizer = gate_logit_normalizer
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[Cache] = None,
147
+ use_cache: Optional[bool] = False,
148
+ output_attentions: Optional[bool] = False,
149
+ **kwargs
150
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
151
+ if attention_mask is not None:
152
+ assert len(attention_mask.shape) == 2, (
153
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
154
+ "for padding purposes (0 indicating padding). "
155
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
156
+ )
157
+
158
+ # launching the triton kernel for just one token will actually be slower
159
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
160
+
161
+ last_state = None
162
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
163
+ last_state = past_key_values[self.layer_idx]
164
+
165
+ cu_seqlens = kwargs.get('cu_seqlens', None)
166
+ if self.use_short_conv:
167
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
168
+ if last_state is not None:
169
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
170
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
171
+ q, conv_state_q = self.q_conv1d(
172
+ x=self.q_proj(hidden_states),
173
+ mask=conv_mask,
174
+ cache=conv_state_q,
175
+ output_final_state=use_cache,
176
+ cu_seqlens=cu_seqlens
177
+ )
178
+ k, conv_state_k = self.k_conv1d(
179
+ x=self.k_proj(hidden_states),
180
+ mask=conv_mask,
181
+ cache=conv_state_k,
182
+ output_final_state=use_cache,
183
+ cu_seqlens=cu_seqlens
184
+ )
185
+ v, conv_state_v = self.v_conv1d(
186
+ x=self.v_proj(hidden_states),
187
+ mask=conv_mask,
188
+ cache=conv_state_v,
189
+ output_final_state=use_cache,
190
+ cu_seqlens=cu_seqlens
191
+ )
192
+ else:
193
+ q = self.q_proj(hidden_states)
194
+ k = self.k_proj(hidden_states)
195
+ v = self.v_proj(hidden_states)
196
+ gk = self.gk_proj(hidden_states)
197
+
198
+ if self.feature_map_fn is not None:
199
+ q, k = map(self.feature_map_fn, (q, k))
200
+ # dealing with left-padding
201
+ if attention_mask is not None:
202
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
203
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
204
+ if self.num_kv_groups > 1:
205
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
206
+ else:
207
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
208
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
209
+
210
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
211
+ if mode == 'chunk':
212
+ o, recurrent_state = chunk_simple_gla(
213
+ q=q,
214
+ k=k,
215
+ v=v,
216
+ gk=gk,
217
+ initial_state=recurrent_state,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens,
220
+ head_first=False
221
+ )
222
+ elif mode == 'fused_recurrent':
223
+ o, recurrent_state = fused_recurrent_simple_gla(
224
+ q=q,
225
+ k=k,
226
+ v=v,
227
+ gk=gk,
228
+ initial_state=recurrent_state,
229
+ output_final_state=use_cache,
230
+ cu_seqlens=cu_seqlens,
231
+ head_first=False
232
+ )
233
+ else:
234
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
235
+
236
+ if past_key_values is not None:
237
+ past_key_values.update(
238
+ recurrent_state=recurrent_state,
239
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
240
+ layer_idx=self.layer_idx,
241
+ offset=q.shape[1]
242
+ )
243
+
244
+ g = self.g_proj(hidden_states)
245
+ if self.fuse_norm_and_gate:
246
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
247
+ o = self.g_norm_swish_gate(o, g)
248
+ o = rearrange(o, 'b t h d -> b t (h d)')
249
+ else:
250
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
251
+ o = o * self.gate_fn(g)
252
+ o = self.o_proj(o)
253
+
254
+ return o, None, past_key_values
255
+
256
+ def state_size(self, **kwargs) -> int:
257
+ state_size = self.key_dim * self.head_v_dim
258
+ for module in self.children():
259
+ if isinstance(module, ShortConvolution):
260
+ state_size += module.state_size
261
+ return state_size
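For reference, the head-wise decay used above can be sketched as follows: `gk_proj` produces one logit per head, log-sigmoid maps it to a non-positive value, and `gate_logit_normalizer` keeps it near zero so the per-step decay stays close to 1. The shapes below are assumptions.

import torch
import torch.nn.functional as F

gk_logits = torch.randn(2, 16, 4)            # [batch, time, num_heads]
gk = F.logsigmoid(gk_logits) / 16            # gate_logit_normalizer = 16
decay = gk.exp()                             # per-head decay in (0, 1)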
fla/models/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
4
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
5
+ from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
6
+ from fla.models.forgetting_transformer import (
7
+ ForgettingTransformerConfig,
8
+ ForgettingTransformerForCausalLM,
9
+ ForgettingTransformerModel
10
+ )
11
+ from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
12
+ from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel
13
+ from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
14
+ from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
15
+ from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
16
+ from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
17
+ from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel
18
+ from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel
19
+ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
20
+ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
21
+ from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
22
+ from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
23
+ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
24
+ from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
25
+ from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
26
+ from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel
27
+ from fla.models.transformer_with_pruning import TransformerWithPruningConfig, TransformerWithPruningForCausalLM, TransformerWithPruningModel
28
+
29
+ __all__ = [
30
+ 'ABCConfig', 'ABCForCausalLM', 'ABCModel',
31
+ 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
32
+ 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
33
+ 'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
34
+ 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
35
+ 'GLAConfig', 'GLAForCausalLM', 'GLAModel',
36
+ 'GSAConfig', 'GSAForCausalLM', 'GSAModel',
37
+ 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
38
+ 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
39
+ 'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel',
40
+ 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
41
+ 'MambaConfig', 'MambaForCausalLM', 'MambaModel',
42
+ 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
43
+ 'NSAConfig', 'NSAForCausalLM', 'NSAModel',
44
+ 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
45
+ 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
46
+ 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
47
+ 'SambaConfig', 'SambaForCausalLM', 'SambaModel',
48
+ 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
49
+ 'TransformerWithPruningConfig', 'TransformerWithPruningForCausalLM', 'TransformerWithPruningModel',
50
+ 'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
51
+ ]
fla/models/utils.py ADDED
@@ -0,0 +1,147 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+ import transformers
9
+
10
+
11
+ class Cache(transformers.cache_utils.Cache):
12
+ """
13
+ A cache used for storing hidden states produced by flash linear attention models.
14
+
15
+ It stores the states of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
16
+ """
17
+
18
+ is_compileable = True
19
+
20
+ def __init__(
21
+ self,
22
+ seen_tokens: int = 0
23
+ ) -> Cache:
24
+ super().__init__()
25
+
26
+ self.states: List[Dict[str, Any]] = []
27
+
28
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
29
+
30
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
31
+ if layer_idx < len(self):
32
+ return self.states[layer_idx]
33
+ else:
34
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
35
+
36
+ def __iter__(self):
37
+ for state in self.states:
38
+ yield state
39
+
40
+ def __len__(self):
41
+ return len(self.states)
42
+
43
+ def update(
44
+ self,
45
+ recurrent_state: torch.Tensor = None,
46
+ attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
47
+ conv_state: Tuple[torch.Tensor] = None,
48
+ ffn_state: torch.Tensor = None,
49
+ layer_idx: int = 0,
50
+ offset: Optional[int] = 1,
51
+ cache_kwargs: Optional[Dict[str, Any]] = None,
52
+ ) -> Dict[str, Any]:
53
+ """
54
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
55
+
56
+ Args:
57
+ recurrent_state (`torch.Tensor`, `optional`):
58
+ The new recurrent state to cache.
59
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
60
+ The new attention key/value states to cache.
61
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
62
+ The new convolution state to cache.
63
+ layer_idx (`int`, defaults to 0):
64
+ The index of the layer to cache the states for.
65
+ offset (`int`, `optional`, defaults to 1):
66
+ The number of new tokens being processed.
67
+ cache_kwargs (`Dict[str, Any]`, `optional`):
68
+ Additional arguments for the cache subclass.
69
+
70
+ Return:
71
+ Dictionary of the updated state.
72
+ """
73
+
74
+ # Update the number of seen tokens
75
+ if layer_idx == 0:
76
+ self._seen_tokens += offset
77
+
78
+ if attn_state is not None:
79
+ input_size = attn_state[0].shape[-2]
80
+ window_size = cache_kwargs.get('window_size', None)
81
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
82
+ raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
83
+ if len(self.states) <= layer_idx:
84
+ if attn_state is not None:
85
+ if window_size is not None and input_size > window_size:
86
+ attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
87
+ attn_state[1][..., -window_size:, :].contiguous())
88
+ state = dict(
89
+ recurrent_state=recurrent_state,
90
+ attn_state=attn_state,
91
+ conv_state=conv_state,
92
+ ffn_state=ffn_state
93
+ )
94
+ self.states.append(state)
95
+ else:
96
+ state = self.states[layer_idx]
97
+ if recurrent_state is not None:
98
+ state['recurrent_state'] = recurrent_state
99
+ if attn_state is not None:
100
+ key_state, value_state = state['attn_state']
101
+ if window_size is not None and key_state.shape[-2] == window_size:
102
+ # DO NOT allocate new memory if the cache is full
103
+ # roll the key/value states to the left by `input_size`
104
+ key_state = key_state.roll(-input_size, -2)
105
+ value_state = value_state.roll(-input_size, -2)
106
+ # replace the last `input_size` tokens with the new key/value states
107
+ key_state[..., -input_size:, :] = attn_state[0]
108
+ value_state[..., -input_size:, :] = attn_state[1]
109
+ attn_state = (key_state, value_state)
110
+ else:
111
+ attn_state = (torch.cat([key_state, attn_state[0]], -2),
112
+ torch.cat([value_state, attn_state[1]], -2),)
113
+ state['attn_state'] = attn_state
114
+ if conv_state is not None:
115
+ state['conv_state'] = conv_state
116
+ if ffn_state is not None:
117
+ state['ffn_state'] = ffn_state
118
+
119
+ return state
120
+
121
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
122
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
123
+ if len(self.states) <= layer_idx:
124
+ return 0
125
+ return self._seen_tokens
126
+
127
+ def get_max_length(self) -> Optional[int]:
128
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
129
+ return None
130
+
131
+ def to_legacy_cache(self) -> Tuple:
132
+ return tuple(self.states)
133
+
134
+ @classmethod
135
+ @torch.compiler.disable
136
+ def from_legacy_cache(
137
+ cls,
138
+ past_key_values: Optional[Tuple] = None,
139
+ seen_tokens: int = 0
140
+ ) -> Cache:
141
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
142
+
143
+ cache = cls(seen_tokens)
144
+ if isinstance(past_key_values, list):
145
+ for layer_idx in range(len(past_key_values)):
146
+ cache.states.append(past_key_values[layer_idx])
147
+ return cache
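A hedged usage sketch of the `Cache` API documented above; the recurrent state tensor shape is a hypothetical example.

import torch

cache = Cache()
state = cache.update(
    recurrent_state=torch.zeros(2, 4, 64, 128),  # hypothetical [B, H, Dk, Dv] state
    layer_idx=0,
    offset=16,                                   # 16 new tokens were processed
)
assert cache.get_seq_length(0) == 16
assert cache[0] is state                         # per-layer states are plain dicts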
fla/modules/activations.py ADDED
@@ -0,0 +1,471 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Tri Dao, Yu Zhang, Songlin Yang.
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.ops.utils.op import exp, log
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, get_multiprocessor_count, input_guard
11
+
12
+ sigmoid_fwd_codestring = """
13
+ template <typename T> T sigmoid_fwd(T x) {
14
+ return 1.0f / (1.0f + ::exp(-float(x)));
15
+ }
16
+ """
17
+ sigmoid_bwd_codestring = """
18
+ template <typename T> T sigmoid_bwd(T x, T g) {
19
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
20
+ return float(g) * x_sigmoid * (1.0f - x_sigmoid);
21
+ }
22
+ """
23
+
24
+ sigmoid_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
25
+ sigmoid_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
26
+
27
+
28
+ @torch.compiler.disable
29
+ def sigmoid_fwd(x):
30
+ return sigmoid_fwd_jit_fn(x)
31
+
32
+
33
+ @torch.compiler.disable
34
+ def sigmoid_bwd(x, g):
35
+ return sigmoid_bwd_jit_fn(x, g)
36
+
37
+
38
+ class SigmoidFunction(torch.autograd.Function):
39
+
40
+ @staticmethod
41
+ def forward(ctx, x):
42
+ ctx.save_for_backward(x)
43
+ return sigmoid_fwd(x)
44
+
45
+ @staticmethod
46
+ def backward(ctx, dout):
47
+ x, = ctx.saved_tensors
48
+ return sigmoid_bwd(x, dout)
49
+
50
+
51
+ sigmoid = SigmoidFunction.apply
52
+
53
+
54
+ @triton.autotune(
55
+ configs=[
56
+ triton.Config({}, num_warps=num_warps)
57
+ for num_warps in [1, 2, 4, 8, 16, 32]
58
+ ],
59
+ key=['D']
60
+ )
61
+ @triton.jit
62
+ def logsigmoid_fwd_kernel(
63
+ x,
64
+ y,
65
+ temperature,
66
+ T: tl.constexpr,
67
+ D: tl.constexpr,
68
+ B: tl.constexpr
69
+ ):
70
+ i = tl.program_id(0)
71
+ o_i = i * B + tl.arange(0, B)
72
+ m_i = o_i < T
73
+
74
+ b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
75
+ b_m = tl.minimum(0., b_x)
76
+ b_z = 1. + exp(-tl.abs(b_x))
77
+ b_y = (b_m - log(b_z)) / temperature
78
+ tl.store(y + o_i, b_y.to(y.dtype.element_ty), mask=m_i)
79
+
80
+
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=num_warps)
84
+ for num_warps in [1, 2, 4, 8, 16, 32]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit
89
+ def logsigmoid_bwd_kernel(
90
+ x,
91
+ dx,
92
+ dy,
93
+ temperature,
94
+ T: tl.constexpr,
95
+ D: tl.constexpr,
96
+ B: tl.constexpr
97
+ ):
98
+ i = tl.program_id(0)
99
+ o_i = i * B + tl.arange(0, B)
100
+ m_i = o_i < T
101
+
102
+ b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
103
+ b_dy = tl.load(dy + o_i, mask=m_i, other=0.).to(tl.float32)
104
+ b_dx = b_dy * (1. - tl.sigmoid(b_x)) / temperature
105
+ tl.store(dx + o_i, b_dx.to(dx.dtype.element_ty), mask=m_i)
106
+
107
+
108
+ def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
109
+ T, D = x.numel(), x.shape[-1]
110
+ B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))
111
+ y = torch.empty_like(x)
112
+ logsigmoid_fwd_kernel[(triton.cdiv(T, B),)](
113
+ x=x,
114
+ y=y,
115
+ temperature=temperature,
116
+ T=T,
117
+ D=D,
118
+ B=B
119
+ )
120
+ return y
121
+
122
+
123
+ def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
124
+ T, D = x.numel(), x.shape[-1]
125
+ B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))
126
+ dx = torch.empty_like(x)
127
+ logsigmoid_bwd_kernel[(triton.cdiv(T, B),)](
128
+ x=x,
129
+ dx=dx,
130
+ dy=dy,
131
+ temperature=temperature,
132
+ T=T,
133
+ D=D,
134
+ B=B
135
+ )
136
+ return dx
137
+
138
+
139
+ class LogSigmoidFunction(torch.autograd.Function):
140
+
141
+ @staticmethod
142
+ @input_guard
143
+ def forward(ctx, x, temperature):
144
+ ctx.save_for_backward(x,)
145
+ ctx.temperature = temperature
146
+ return logsigmoid_fwd(x, temperature)
147
+
148
+ @staticmethod
149
+ @input_guard
150
+ def backward(ctx, dy):
151
+ x, = ctx.saved_tensors
152
+ return logsigmoid_bwd(x, dy, ctx.temperature), None
153
+
154
+
155
+ def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
156
+ return LogSigmoidFunction.apply(x, temperature)
157
+
158
+
159
+ swish_fwd_codestring = """
160
+ template <typename T> T swish_fwd(T x) {
161
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
162
+ return float(x) * x_sigmoid;
163
+ }
164
+ """
165
+ swish_bwd_codestring = """
166
+ template <typename T> T swish_bwd(T x, T g) {
167
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
168
+ return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
169
+ }
170
+ """
171
+
172
+ swish_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
173
+ swish_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
174
+
175
+
176
+ @torch.compiler.disable
177
+ def swish_fwd(x):
178
+ return swish_fwd_jit_fn(x)
179
+
180
+
181
+ @torch.compiler.disable
182
+ def swish_bwd(x, g):
183
+ return swish_bwd_jit_fn(x, g)
184
+
185
+
186
+ class SwishFunction(torch.autograd.Function):
187
+
188
+ @staticmethod
189
+ def forward(ctx, x):
190
+ ctx.save_for_backward(x)
191
+ return swish_fwd(x)
192
+
193
+ @staticmethod
194
+ def backward(ctx, dout):
195
+ x, = ctx.saved_tensors
196
+ return swish_bwd(x, dout)
197
+
198
+
199
+ swish = SwishFunction.apply
200
+
201
+ # 1/sqrt(2*pi)-> 0.3989423
202
+ # 1/sqrt(2) -> 0.70710678
203
+ # sqrt(2/pi) -> 0.79788456
204
+
205
+
206
+ # this function is tanh approximation of gelu
207
+ # actual gelu is:
208
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
209
+ @torch.compile
210
+ def bias_gelu(y, bias):
211
+ x = bias + y
212
+ return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
213
+
214
+
215
+ # gradient of tanh approximation of gelu
216
+ # gradient of actual gelu is:
217
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
218
+ @torch.compile
219
+ def bias_gelu_bwd(g, y, bias):
220
+ """Assume that y has shape (B, D) and bias has shape (D)"""
221
+ x = bias + y
222
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
223
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
224
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
225
+ 1 + tanh_out
226
+ )
227
+ grad_y = ff * g
228
+ return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
229
+
230
+
231
+ class GeLUFunction(torch.autograd.Function):
232
+
233
+ @staticmethod
234
+ # bias is an optional argument
235
+ def forward(ctx, input, bias):
236
+ ctx.save_for_backward(input, bias)
237
+ return bias_gelu(input, bias)
238
+
239
+ @staticmethod
240
+ def backward(ctx, grad_output):
241
+ input, bias = ctx.saved_tensors
242
+ tmp = bias_gelu_bwd(grad_output, input, bias)
243
+ return tmp  # bias_gelu_bwd already returns (grad_input, grad_bias)
244
+
245
+
246
+ bias_gelu_impl = GeLUFunction.apply
247
+
248
+
249
+ # this function is tanh approximation of gelu
250
+ # actual gelu is:
251
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
252
+ @torch.compile
253
+ def gelu_fwd(x):
254
+ return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
255
+
256
+
257
+ # gradient of tanh approximation of gelu
258
+ # gradient of actual gelu is:
259
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
260
+ @torch.compile
261
+ def gelu_bwd(g, x):
262
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
263
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
264
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
265
+ 1 + tanh_out
266
+ )
267
+ return (ff * g).to(dtype=x.dtype)
268
+
269
+
270
+ class FastGeLUFunction(torch.autograd.Function):
271
+ @staticmethod
272
+ # bias is an optional argument
273
+ def forward(ctx, input):
274
+ ctx.save_for_backward(input)
275
+ return gelu_fwd(input)
276
+
277
+ @staticmethod
278
+ def backward(ctx, grad_output):
279
+ (input,) = ctx.saved_tensors
280
+ tmp = gelu_bwd(grad_output, input)
281
+ return tmp
282
+
283
+
284
+ fast_gelu_impl = FastGeLUFunction.apply
285
+
286
+
287
+ @torch.compile
288
+ def relu_bwd(g, x):
289
+ return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
290
+
291
+
292
+ @torch.compile
293
+ def sqrelu_fwd(x):
294
+ r = F.relu(x.float())
295
+ return (r * r).to(dtype=x.dtype)
296
+
297
+
298
+ @torch.compile
299
+ def sqrelu_bwd(g, x):
300
+ return (2.0 * g * F.relu(x.float())).to(dtype=x.dtype)
301
+
302
+
303
+ class SquaredReLUFunction(torch.autograd.Function):
304
+
305
+ @staticmethod
306
+ def forward(ctx, input):
307
+ ctx.save_for_backward(input)
308
+ return sqrelu_fwd(input)
309
+
310
+ @staticmethod
311
+ def backward(ctx, grad_output):
312
+ input, = ctx.saved_tensors
313
+ return sqrelu_bwd(grad_output, input)
314
+
315
+
316
+ sqrelu = SquaredReLUFunction.apply
317
+
318
+
319
+ swiglu_fwd_codestring = """
320
+ template <typename T> T swiglu_fwd(T x, T y) {
321
+ return float(x) * float(y) / (1.0f + ::exp(-float(x)));
322
+ }
323
+ """
324
+ swiglu_bwd_codestring = """
325
+ template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
326
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
327
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
328
+ dy = float(x) * x_sigmoid * float(g);
329
+ }
330
+ """
331
+
332
+ swiglu_fwdbwd_codestring = """
333
+ template <typename T> T swiglu_fwdbwd(T x, T y, T g, T& dx, T& dy, T& z) {
334
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
335
+ float x_swish = float(x) * x_sigmoid;
336
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
337
+ dy = x_swish * float(g);
338
+ z = x_swish * float(y);
339
+ }
340
+ """
341
+
342
+
343
+ swiglu_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
344
+ swiglu_bwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
345
+ swiglu_fwdbwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_fwdbwd_codestring, num_outputs=3)
346
+
347
+
348
+ @torch.compiler.disable
349
+ def swiglu_fwd(x, y):
350
+ return swiglu_fwd_jit_fn(x, y)
351
+
352
+
353
+ @torch.compiler.disable
354
+ def swiglu_bwd(x, y, g):
355
+ return swiglu_bwd_jit_fn(x, y, g)
356
+
357
+
358
+ @torch.compiler.disable
359
+ def swiglu_fwdbwd(x, y, g):
360
+ return swiglu_fwdbwd_jit_fn(x, y, g)
361
+
362
+
363
+ @torch.compile
364
+ def swiglu_fwd_torch(x, y):
365
+ return (F.silu(x.float()) * y).to(x.dtype)
366
+
367
+
368
+ @torch.compile
369
+ def swiglu_bwd_torch(x, y, g):
370
+ dtype = x.dtype
371
+ x, y, g = x.float(), y.float(), g.float()
372
+ x_sigmoid = x.sigmoid()
373
+ x_swish = x * x_sigmoid
374
+ dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y
375
+ dy = x_swish * g
376
+ return dx.to(dtype), dy.to(dtype)
377
+
378
+
379
+ @torch.compile
380
+ def swiglu_fwdbwd_torch(x, y, g):
381
+ dtype = x.dtype
382
+ x, y, g = x.float(), y.float(), g.float()
383
+ x_sigmoid = x.sigmoid()
384
+ x_swish = x * x_sigmoid
385
+ dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y
386
+ dy = x_swish * g
387
+ z = x_swish * y
388
+ return dx.to(dtype), dy.to(dtype), z.to(dtype)
389
+
390
+
391
+ class SwiGLUFunction(torch.autograd.Function):
392
+ r"""
393
+ Swish-Gated Linear Unit (SwiGLU) function.
394
+
395
+ .. math::
396
+ \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
397
+ """
398
+
399
+ @staticmethod
400
+ def forward(ctx, x, y):
401
+ ctx.save_for_backward(x, y)
402
+ if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
403
+ return swiglu_fwd_torch(x, y)
404
+ else:
405
+ return swiglu_fwd(x, y)
406
+
407
+ @staticmethod
408
+ def backward(ctx, dout):
409
+ x, y = ctx.saved_tensors
410
+ if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
411
+ return swiglu_bwd_torch(x, y, dout)
412
+ else:
413
+ return swiglu_bwd(x, y, dout)
414
+
415
+
416
+ class SwiGLULinearFunction(torch.autograd.Function):
417
+ r"""
418
+ Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.
419
+
420
+ .. math::
421
+ \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b
422
+
423
+ This simple wrapper discards the intermediate results of SwiGLU(x, y) to save memory.
424
+ """
425
+
426
+ @staticmethod
427
+ @autocast_custom_fwd
428
+ def forward(ctx, x, y, weight, bias):
429
+ with torch.no_grad():
430
+ if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
431
+ z = swiglu_fwd_torch(x, y)
432
+ else:
433
+ z = swiglu_fwd(x, y)
434
+ out = F.linear(z, weight, bias)
435
+ # We don't store z, will be recomputed in the backward pass to save memory
436
+ ctx.save_for_backward(x, y, weight)
437
+ ctx.linear_bias_is_none = bias is None
438
+ return out
439
+
440
+ @staticmethod
441
+ @autocast_custom_bwd
442
+ def backward(ctx, dout, *args):
443
+ x, y, weight = ctx.saved_tensors
444
+ dout = dout.reshape(-1, dout.shape[-1])
445
+ dz = F.linear(dout, weight.t()).view_as(x)
446
+ with torch.no_grad():
447
+ if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
448
+ dx, dy, z = swiglu_fwdbwd_torch(x, y, dz)
449
+ else:
450
+ dx, dy, z = swiglu_fwdbwd(x, y, dz)
451
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
452
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
453
+ return dx, dy, dlinear_weight, dlinear_bias
454
+
455
+
456
+ swiglu = SwiGLUFunction.apply
457
+
458
+
459
+ swiglu_linear = SwiGLULinearFunction.apply
460
+
461
+
462
+ ACT2FN = {
463
+ 'relu': F.relu,
464
+ 'sigmoid': sigmoid,
465
+ 'logsigmoid': logsigmoid,
466
+ 'silu': swish,
467
+ 'swish': swish,
468
+ 'sqrelu': sqrelu,
469
+ 'gelu': fast_gelu_impl,
470
+ 'bias_gelu': bias_gelu_impl,
471
+ }
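For reference, a minimal usage sketch of the activations defined above (a sketch only: the jiterator- and Triton-backed paths assume a CUDA device, and the shapes are illustrative):

    import torch
    from fla.modules.activations import ACT2FN, logsigmoid, swiglu

    x = torch.randn(8, 2048, device='cuda', requires_grad=True)
    g = torch.randn_like(x).requires_grad_()

    y1 = ACT2FN['swish'](x)               # fused swish/SiLU via the jiterator kernels
    y2 = logsigmoid(x, temperature=2.0)   # Triton log-sigmoid with temperature scaling
    y3 = swiglu(x, g)                     # swish(x) * g with a custom fused backward
    (y1.sum() + y2.sum() + y3.sum()).backward()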
fla/modules/fused_cross_entropy.py ADDED
@@ -0,0 +1,419 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+
5
+ from typing import Any, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import input_guard
14
+
15
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
16
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
17
+ # version of PyTorch. The following 2 lines are for backward compatibility with
18
+ # older PyTorch.
19
+ if "all_gather_into_tensor" not in dir(torch.distributed):
20
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
21
+
22
+
23
+ @triton.heuristics({
24
+ "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
25
+ })
26
+ @triton.jit
27
+ def cross_entropy_fwd_kernel(
28
+ loss_ptr, # data ptrs
29
+ lse_ptr,
30
+ z_loss_ptr,
31
+ logits_ptr,
32
+ labels_ptr,
33
+ label_smoothing,
34
+ logit_scale,
35
+ lse_square_scale,
36
+ ignore_index,
37
+ total_classes,
38
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
39
+ n_cols, # shapes
40
+ n_rows,
41
+ logits_row_stride, # strides
42
+ BLOCK_SIZE: tl.constexpr,
43
+ HAS_SMOOTHING: tl.constexpr,
44
+ # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
45
+ SPLIT: tl.constexpr,
46
+ ):
47
+ row_idx = tl.program_id(0)
48
+ col_block_idx = tl.program_id(1)
49
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
50
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
51
+ label_idx = tl.load(labels_ptr + row_idx)
52
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf"))
53
+ logits = logits.to(tl.float32) * logit_scale
54
+ max_logits = tl.max(logits, 0)
55
+ if HAS_SMOOTHING:
56
+ sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
57
+ lse = log(tl.sum(exp(logits - max_logits), 0)) + max_logits
58
+ tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
59
+ if label_idx == ignore_index:
60
+ loss = 0.0
61
+ z_loss = 0.0
62
+ else:
63
+ label_idx -= class_start_idx
64
+ if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
65
+ n_cols, (col_block_idx + 1) * BLOCK_SIZE
66
+ ):
67
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
68
+ if HAS_SMOOTHING:
69
+ loss = (
70
+ (lse if not SPLIT else 0.0)
71
+ - label_smoothing * sum_logits / total_classes
72
+ - (1 - label_smoothing) * logits_label
73
+ )
74
+ else:
75
+ loss = (lse if not SPLIT else 0.0) - logits_label
76
+ else:
77
+ # If label is out of bounds, we set the CE loss to 0.0. But we still want the label_smoothing loss
78
+ if HAS_SMOOTHING:
79
+ loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
80
+ else:
81
+ loss = 0.0
82
+ if not SPLIT:
83
+ z_loss = lse_square_scale * lse * lse
84
+ loss += z_loss
85
+ else:
86
+ z_loss = 0.0
87
+ tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
88
+ if not SPLIT:
89
+ tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
90
+
91
+
92
+ @triton.heuristics({
93
+ "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
94
+ })
95
+ @triton.jit
96
+ def cross_entropy_bwd_kernel(
97
+ dlogits_ptr, # data ptrs
98
+ dloss_ptr,
99
+ logits_ptr,
100
+ lse_ptr,
101
+ labels_ptr,
102
+ label_smoothing,
103
+ logit_scale,
104
+ lse_square_scale,
105
+ ignore_index,
106
+ total_classes,
107
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
108
+ n_cols, # shapes
109
+ logits_row_stride, # strides
110
+ dlogits_row_stride,
111
+ dloss_row_stride,
112
+ BLOCK_SIZE: tl.constexpr,
113
+ HAS_SMOOTHING: tl.constexpr,
114
+ ):
115
+ row_idx = tl.program_id(0)
116
+ col_block_idx = tl.program_id(1)
117
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
118
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
119
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
120
+ label_idx = tl.load(labels_ptr + row_idx)
121
+ if label_idx != ignore_index:
122
+ dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
123
+ else:
124
+ dloss = 0.0
125
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
126
+ tl.float32
127
+ ) * logit_scale
128
+ lse = tl.load(lse_ptr + row_idx)
129
+ probs = exp(logits - lse)
130
+ probs += 2.0 * lse_square_scale * lse * probs
131
+ label_idx -= class_start_idx
132
+ if HAS_SMOOTHING:
133
+ smooth_negative = label_smoothing / total_classes
134
+ probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative
135
+ else:
136
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
137
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
138
+
139
+
140
+ def fused_cross_entropy_forward(
141
+ logits: torch.Tensor,
142
+ target: torch.Tensor,
143
+ label_smoothing: float = 0.0,
144
+ logit_scale: float = 1.0,
145
+ lse_square_scale: float = 0.0,
146
+ ignore_index: int = -100,
147
+ process_group=None,
148
+ ):
149
+ n_rows, n_cols = logits.shape
150
+ assert target.shape == (n_rows,)
151
+ world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
152
+ total_classes = world_size * n_cols
153
+ rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
154
+ class_start_idx = rank * n_cols
155
+
156
+ if logits.stride(-1) != 1:
157
+ logits = logits.contiguous()
158
+ # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
159
+ MAX_BLOCK_SIZE = 64 * 1024
160
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
161
+ num_warps = (
162
+ 4
163
+ if BLOCK_SIZE < 2048
164
+ else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
165
+ )
166
+ # We may split the lse computation across multiple blocks, then do a reduction
167
+ # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
168
+ # where having just one thread block processing more than 64k elements is slow.
169
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
170
+ n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
171
+ loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
172
+ losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
173
+ lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
174
+ z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
175
+
176
+ cross_entropy_fwd_kernel[(n_rows, n_splits)](
177
+ losses, # data ptrs
178
+ lse,
179
+ z_losses,
180
+ logits,
181
+ target,
182
+ label_smoothing,
183
+ logit_scale,
184
+ lse_square_scale,
185
+ ignore_index,
186
+ total_classes,
187
+ class_start_idx,
188
+ n_cols, # shapes
189
+ n_rows,
190
+ logits.stride(0), # strides
191
+ BLOCK_SIZE=BLOCK_SIZE, # constants
192
+ num_warps=num_warps,
193
+ SPLIT=split
194
+ )
195
+
196
+ if split:
197
+ # If there's no label_smoothing, if target are in the vocab of this partition, losses contains
198
+ # - predicted logit, and 0 otherwise.
199
+ # If there's label_smoothing=0.1, for target in the vocab of this partition, losses contains
200
+ # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
201
+ # For target not in the vocab of this partition, losses contains
202
+ # -0.1 * sum logit / total_classes.
203
+ if n_splits > 1:
204
+ lse = torch.logsumexp(lse, dim=0)
205
+ losses = losses.sum(dim=0)
206
+ if world_size > 1:
207
+ lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
208
+ torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
209
+ handle_losses = torch.distributed.all_reduce(
210
+ losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
211
+ )
212
+ lse = torch.logsumexp(lse_allgather, dim=0)
213
+ handle_losses.wait()
214
+ # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit,
215
+ # we just have to add the (global) lse.
216
+ # If there's label_smoothing=0.1, the total losses are
217
+ # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
218
+ # Again, we just have to add the (global) lse.
219
+ losses += lse
220
+ if lse_square_scale != 0.0:
221
+ z_losses = lse_square_scale * lse.square()
222
+ z_losses.masked_fill_(target == ignore_index, 0.0)
223
+ losses += z_losses
224
+ else:
225
+ z_losses = torch.zeros_like(losses)
226
+ losses.masked_fill_(target == ignore_index, 0.0)
227
+
228
+ return losses, z_losses, lse, total_classes, class_start_idx
229
+
230
+
231
+ class CrossEntropyLossFunction(torch.autograd.Function):
232
+
233
+ @staticmethod
234
+ @input_guard
235
+ def forward(
236
+ ctx,
237
+ logits,
238
+ target,
239
+ label_smoothing=0.0,
240
+ logit_scale=1.0,
241
+ lse_square_scale=0.0,
242
+ ignore_index=-100,
243
+ inplace_backward=False,
244
+ process_group=None,
245
+ ):
246
+ losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward(
247
+ logits,
248
+ target,
249
+ label_smoothing,
250
+ logit_scale,
251
+ lse_square_scale,
252
+ ignore_index,
253
+ process_group,
254
+ )
255
+ ctx.save_for_backward(logits, lse, target)
256
+ ctx.mark_non_differentiable(z_losses)
257
+ ctx.label_smoothing = label_smoothing
258
+ ctx.logit_scale = logit_scale
259
+ ctx.lse_square_scale = lse_square_scale
260
+ ctx.ignore_index = ignore_index
261
+ ctx.total_classes = total_classes
262
+ ctx.class_start_idx = class_start_idx
263
+ ctx.inplace_backward = inplace_backward
264
+
265
+ return losses, z_losses
266
+
267
+ @staticmethod
268
+ @input_guard
269
+ def backward(ctx, grad_losses, grad_z_losses):
270
+ del grad_z_losses # z_losses are only for logging.
271
+
272
+ logits, lse, target = ctx.saved_tensors
273
+ dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
274
+ n_rows, n_cols = logits.shape
275
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
276
+ num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
277
+ def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
278
+ cross_entropy_bwd_kernel[grid](
279
+ dlogits, # data ptrs
280
+ grad_losses,
281
+ logits,
282
+ lse,
283
+ target,
284
+ ctx.label_smoothing,
285
+ ctx.logit_scale,
286
+ ctx.lse_square_scale,
287
+ ctx.ignore_index,
288
+ ctx.total_classes,
289
+ ctx.class_start_idx,
290
+ n_cols, # shapes
291
+ logits.stride(0), # strides
292
+ dlogits.stride(0),
293
+ grad_losses.stride(0),
294
+ BLOCK_SIZE=BLOCK_SIZE, # constants
295
+ num_warps=num_warps,
296
+ )
297
+ return dlogits, None, None, None, None, None, None, None, None
298
+
299
+
300
+ def cross_entropy_loss(
301
+ logits: torch.Tensor,
302
+ target: torch.Tensor,
303
+ label_smoothing: float = 0.0,
304
+ logit_scale: float = 1.0,
305
+ lse_square_scale: float = 0.0,
306
+ ignore_index=-100,
307
+ inplace_backward: bool = False,
308
+ process_group=None,
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ """
311
+ Arguments:
312
+ logits: [batch, vocab_size]
313
+ target: [batch,]
314
+ label_smoothing: float
315
+ logit_scale: float.
316
+ Multiply logits by this scale before calculating the loss.
317
+ lse_square_scale: float.
318
+ If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
319
+ This is also referred to as "z-loss".
320
+ ignore_index: int.
321
+ If target == ignore_index, the loss is set to 0.0.
322
+ inplace_backward: bool.
323
+ If True, we do the backward pass in-place by modifying the logits.
324
+ This saves memory.
325
+ process_group:
326
+ if not None, we're doing Tensor Parallel: each process is responsible for
327
+ one part of the vocab. The loss will be aggregated across processes.
328
+ Returns:
329
+ losses: [batch,], float
330
+ z_losses: [batch,], float
331
+ """
332
+ return CrossEntropyLossFunction.apply(
333
+ logits,
334
+ target,
335
+ label_smoothing,
336
+ logit_scale,
337
+ lse_square_scale,
338
+ ignore_index,
339
+ inplace_backward,
340
+ process_group,
341
+ )
342
+
343
+
344
+ class FusedCrossEntropyLoss(nn.Module):
345
+ def __init__(
346
+ self,
347
+ ignore_index: int = -100,
348
+ reduction: str = "mean",
349
+ label_smoothing: float = 0.0,
350
+ logit_scale: float = 1.0,
351
+ lse_square_scale: float = 0.0,
352
+ inplace_backward: bool = False,
353
+ process_group: Any = None,
354
+ return_z_loss: bool = False,
355
+ ):
356
+ """
357
+ Arguments:
358
+ ignore_index: int. If target == ignore_index, the loss is set to 0.0.
359
+ label_smoothing: float
360
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
361
+ This is also referred to as "z-loss".
362
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
363
+ This saves memory.
364
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
365
+ one part of the vocab. The loss will be aggregated across processes.
366
+ return_z_loss: bool. If True, we return the component of the loss contributed by
367
+ the lse_square_scale value. This value is only for logging and does not support
368
+ backprop.
369
+ """
370
+ super().__init__()
371
+ if reduction not in ["mean", "none", "sum"]:
372
+ raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
373
+ self.ignore_index = ignore_index
374
+ self.reduction = reduction
375
+ self.label_smoothing = label_smoothing
376
+ self.logit_scale = logit_scale
377
+ self.lse_square_scale = lse_square_scale
378
+ self.inplace_backward = inplace_backward
379
+ self.process_group = process_group
380
+ self.return_z_loss = return_z_loss
381
+
382
+ def forward(self, input, target):
383
+ """
384
+ Arguments:
385
+ input: (batch, vocab_size)
386
+ target: (batch,)
387
+ Returns:
388
+ losses: (batch,) if reduction is 'none', else (1,), dtype float
389
+ z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
390
+ """
391
+ assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
392
+ loss, z_loss = cross_entropy_loss(
393
+ input,
394
+ target,
395
+ label_smoothing=self.label_smoothing,
396
+ logit_scale=self.logit_scale,
397
+ lse_square_scale=self.lse_square_scale,
398
+ ignore_index=self.ignore_index,
399
+ inplace_backward=self.inplace_backward,
400
+ process_group=self.process_group,
401
+ )
402
+ if self.reduction == "mean":
403
+ loss = loss.sum() / (target != self.ignore_index).sum()
404
+ elif self.reduction == "sum":
405
+ loss = loss.sum()
406
+ else:
407
+ loss = loss
408
+
409
+ if not self.return_z_loss:
410
+ return loss
411
+
412
+ if self.reduction == "mean":
413
+ z_loss = z_loss.sum() / (target != self.ignore_index).sum()
414
+ elif self.reduction == "sum":
415
+ z_loss = z_loss.sum()
416
+ else:
417
+ z_loss = z_loss
418
+
419
+ return loss, z_loss
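A minimal usage sketch of the fused loss above (a sketch only: CUDA tensors are required, as asserted in `forward`, and the batch/vocabulary sizes are illustrative):

    import torch
    from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss

    criterion = FusedCrossEntropyLoss(
        ignore_index=-100,
        label_smoothing=0.1,
        lse_square_scale=1e-4,    # enables the auxiliary z-loss term
        return_z_loss=True,
    )
    logits = torch.randn(16, 32000, device='cuda', requires_grad=True)
    target = torch.randint(0, 32000, (16,), device='cuda')
    loss, z_loss = criterion(logits, target)   # z_loss is non-differentiable, for logging only
    loss.backward()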
fla/modules/fused_norm_gate.py ADDED
@@ -0,0 +1,995 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from fla.utils import get_multiprocessor_count, input_guard
16
+
17
+
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
21
+ for num_warps in [1, 2, 4, 8, 16, 32]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['N', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM', 'HAS_BIAS'],
25
+ )
26
+ @triton.jit
27
+ def layer_norm_gated_fwd_kernel(
28
+ X, # pointer to the input
29
+ G, # pointer to the gate
30
+ Y, # pointer to the output
31
+ W, # pointer to the weights
32
+ B, # pointer to the biases
33
+ RESIDUAL, # pointer to the residual
34
+ RESIDUAL_OUT, # pointer to the residual
35
+ Mean, # pointer to the mean
36
+ Rstd, # pointer to the 1/std
37
+ N, # number of columns in X
38
+ eps, # epsilon to avoid division by zero
39
+ ACTIVATION: tl.constexpr,
40
+ IS_RMS_NORM: tl.constexpr,
41
+ BLOCK_N: tl.constexpr,
42
+ HAS_RESIDUAL: tl.constexpr,
43
+ STORE_RESIDUAL_OUT: tl.constexpr,
44
+ HAS_WEIGHT: tl.constexpr,
45
+ HAS_BIAS: tl.constexpr
46
+ ):
47
+ # Map the program id to the row of X and Y it should compute.
48
+ row = tl.program_id(0)
49
+ X += row * N
50
+ Y += row * N
51
+ G += row * N
52
+ if HAS_RESIDUAL:
53
+ RESIDUAL += row * N
54
+ if STORE_RESIDUAL_OUT:
55
+ RESIDUAL_OUT += row * N
56
+ # Compute mean and variance
57
+ cols = tl.arange(0, BLOCK_N)
58
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
59
+ if HAS_RESIDUAL:
60
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
61
+ x += residual
62
+ if STORE_RESIDUAL_OUT:
63
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
64
+ if not IS_RMS_NORM:
65
+ mean = tl.sum(x, axis=0) / N
66
+ tl.store(Mean + row, mean)
67
+ xbar = tl.where(cols < N, x - mean, 0.0)
68
+ var = tl.sum(xbar * xbar, axis=0) / N
69
+ else:
70
+ xbar = tl.where(cols < N, x, 0.0)
71
+ var = tl.sum(xbar * xbar, axis=0) / N
72
+ rstd = 1 / tl.sqrt(var + eps)
73
+ tl.store(Rstd + row, rstd)
74
+ # Normalize and apply linear transformation
75
+ mask = cols < N
76
+ if HAS_WEIGHT:
77
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
78
+ if HAS_BIAS:
79
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
80
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
81
+ y = x_hat * w if HAS_WEIGHT else x_hat
82
+ if HAS_BIAS:
83
+ y = y + b
84
+
85
+ # Swish output gate
86
+ g = tl.load(G + cols, mask=cols < N, other=0.0).to(tl.float32)
87
+ if ACTIVATION == 'swish':
88
+ y = y * g * tl.sigmoid(g)
89
+ elif ACTIVATION == 'silu':
90
+ y = y * g * tl.sigmoid(g)
91
+ elif ACTIVATION == 'sigmoid':
92
+ y = y * tl.sigmoid(g)
93
+
94
+ # Write output
95
+ tl.store(Y + cols, y, mask=mask)
96
+
97
+
98
+ def layer_norm_gated_fwd(
99
+ x: torch.Tensor,
100
+ g: torch.Tensor,
101
+ weight: torch.Tensor,
102
+ bias: torch.Tensor,
103
+ activation: str = 'swish',
104
+ eps: float = 1e-5,
105
+ residual: torch.Tensor = None,
106
+ out_dtype: torch.dtype = None,
107
+ residual_dtype: torch.dtype = None,
108
+ is_rms_norm: bool = False
109
+ ):
110
+ if residual is not None:
111
+ residual_dtype = residual.dtype
112
+ M, N = x.shape
113
+ if residual is not None:
114
+ assert residual.shape == (M, N)
115
+ if weight is not None:
116
+ assert weight.shape == (N,)
117
+ if bias is not None:
118
+ assert bias.shape == (N,)
119
+ # allocate output
120
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
121
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
122
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
123
+ else:
124
+ residual_out = None
125
+ mean = torch.empty((M,), dtype=torch.float, device=x.device) if not is_rms_norm else None
126
+ rstd = torch.empty((M,), dtype=torch.float, device=x.device)
127
+ # Less than 64KB per feature: enqueue fused kernel
128
+ MAX_FUSED_SIZE = 65536 // x.element_size()
129
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
130
+ if N > BLOCK_N:
131
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
132
+ # heuristics for number of warps
133
+
134
+ layer_norm_gated_fwd_kernel[(M,)](
135
+ x,
136
+ g,
137
+ y,
138
+ weight,
139
+ bias,
140
+ residual,
141
+ residual_out,
142
+ mean,
143
+ rstd,
144
+ N,
145
+ eps,
146
+ ACTIVATION=activation,
147
+ IS_RMS_NORM=is_rms_norm,
148
+ BLOCK_N=BLOCK_N,
149
+ HAS_RESIDUAL=residual is not None,
150
+ STORE_RESIDUAL_OUT=residual_out is not None,
151
+ HAS_WEIGHT=weight is not None,
152
+ HAS_BIAS=bias is not None,
153
+ )
154
+ # residual_out is None if residual is None and residual_dtype == input_dtype
155
+ return y, mean, rstd, residual_out if residual_out is not None else x
156
+
157
+
158
+ @triton.heuristics({
159
+ 'RECOMPUTE_OUTPUT': lambda args: args["Y"] is not None
160
+ })
161
+ @triton.autotune(
162
+ configs=[
163
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
164
+ for num_warps in [1, 2, 4, 8, 16, 32]
165
+ for num_stages in [2, 3, 4]
166
+ ],
167
+ key=['N', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM', 'HAS_BIAS'],
168
+ )
169
+ @triton.jit
170
+ def layer_norm_gated_bwd_kernel(
171
+ X, # pointer to the input
172
+ G, # pointer to the gate
173
+ W, # pointer to the weights
174
+ B, # pointer to the biases
175
+ Y, # pointer to the output to be recomputed
176
+ DY, # pointer to the output gradient
177
+ DX, # pointer to the input gradient
178
+ DG, # pointer to the gate gradient
179
+ DW, # pointer to the partial sum of weights gradient
180
+ DB, # pointer to the partial sum of biases gradient
181
+ DRESIDUAL,
182
+ DRESIDUAL_IN,
183
+ Mean, # pointer to the mean
184
+ Rstd, # pointer to the 1/std
185
+ M, # number of rows in X
186
+ N, # number of columns in X
187
+ eps, # epsilon to avoid division by zero
188
+ rows_per_program,
189
+ ACTIVATION: tl.constexpr,
190
+ IS_RMS_NORM: tl.constexpr,
191
+ BLOCK_N: tl.constexpr,
192
+ HAS_DRESIDUAL: tl.constexpr,
193
+ STORE_DRESIDUAL: tl.constexpr,
194
+ HAS_WEIGHT: tl.constexpr,
195
+ HAS_BIAS: tl.constexpr,
196
+ RECOMPUTE_OUTPUT: tl.constexpr,
197
+ ):
198
+ # Map the program id to the elements of X, DX, and DY it should compute.
199
+ row_block_id = tl.program_id(0)
200
+ row_start = row_block_id * rows_per_program
201
+ cols = tl.arange(0, BLOCK_N)
202
+ mask = cols < N
203
+ X += row_start * N
204
+ G += row_start * N
205
+ if HAS_DRESIDUAL:
206
+ DRESIDUAL += row_start * N
207
+ if STORE_DRESIDUAL:
208
+ DRESIDUAL_IN += row_start * N
209
+ DY += row_start * N
210
+ DX += row_start * N
211
+ DG += row_start * N
212
+ if RECOMPUTE_OUTPUT:
213
+ Y += row_start * N
214
+ if HAS_WEIGHT:
215
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
216
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
217
+ if HAS_BIAS:
218
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
219
+ if HAS_BIAS:
220
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
221
+
222
+ row_end = min((row_block_id + 1) * rows_per_program, M)
223
+ for row in range(row_start, row_end):
224
+ # Load data to SRAM
225
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
226
+ g = tl.load(G + cols, mask=mask, other=0).to(tl.float32)
227
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
228
+
229
+ if not IS_RMS_NORM:
230
+ mean = tl.load(Mean + row)
231
+ rstd = tl.load(Rstd + row)
232
+ # Compute dx
233
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
234
+ xhat = tl.where(mask, xhat, 0.0)
235
+
236
+ y = xhat * w if HAS_WEIGHT else xhat
237
+ if HAS_BIAS:
238
+ y = y + b
239
+ if RECOMPUTE_OUTPUT:
240
+ tl.store(Y + cols, y, mask=mask)
241
+
242
+ sigmoid_g = tl.sigmoid(g)
243
+ if ACTIVATION == 'swish':
244
+ dg = dy * y * (sigmoid_g + g * sigmoid_g * (1 - sigmoid_g))
245
+ dy = dy * g * sigmoid_g
246
+ elif ACTIVATION == 'silu':
247
+ dg = dy * y * (sigmoid_g + g * sigmoid_g * (1 - sigmoid_g))
248
+ dy = dy * g * sigmoid_g
249
+ elif ACTIVATION == 'sigmoid':
250
+ dg = dy * y * sigmoid_g * (1 - sigmoid_g)
251
+ dy = dy * sigmoid_g
252
+ wdy = dy
253
+ if HAS_WEIGHT:
254
+ wdy = dy * w
255
+ dw += dy * xhat
256
+ if HAS_BIAS:
257
+ db += dy
258
+ if not IS_RMS_NORM:
259
+ c1 = tl.sum(xhat * wdy, axis=0) / N
260
+ c2 = tl.sum(wdy, axis=0) / N
261
+ dx = (wdy - (xhat * c1 + c2)) * rstd
262
+ else:
263
+ c1 = tl.sum(xhat * wdy, axis=0) / N
264
+ dx = (wdy - xhat * c1) * rstd
265
+ if HAS_DRESIDUAL:
266
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
267
+ dx += dres
268
+ # Write dx
269
+ if STORE_DRESIDUAL:
270
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
271
+ tl.store(DX + cols, dx, mask=mask)
272
+ tl.store(DG + cols, dg, mask=mask)
273
+
274
+ X += N
275
+ G += N
276
+ if HAS_DRESIDUAL:
277
+ DRESIDUAL += N
278
+ if STORE_DRESIDUAL:
279
+ DRESIDUAL_IN += N
280
+ if RECOMPUTE_OUTPUT:
281
+ Y += N
282
+ DY += N
283
+ DX += N
284
+ DG += N
285
+ if HAS_WEIGHT:
286
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
287
+ if HAS_BIAS:
288
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
289
+
290
+
291
+ def layer_norm_gated_bwd(
292
+ dy: torch.Tensor,
293
+ x: torch.Tensor,
294
+ g: torch.Tensor,
295
+ weight: torch.Tensor,
296
+ bias: torch.Tensor,
297
+ activation: str = 'swish',
298
+ eps: float = 1e-5,
299
+ mean: torch.Tensor = None,
300
+ rstd: torch.Tensor = None,
301
+ dresidual: torch.Tensor = None,
302
+ has_residual: bool = False,
303
+ is_rms_norm: bool = False,
304
+ x_dtype: torch.dtype = None,
305
+ recompute_output: bool = False,
306
+ ):
307
+ M, N = x.shape
308
+ assert dy.shape == (M, N)
309
+ if dresidual is not None:
310
+ assert dresidual.shape == (M, N)
311
+ if weight is not None:
312
+ assert weight.shape == (N,)
313
+ if bias is not None:
314
+ assert bias.shape == (N,)
315
+ # allocate output
316
+ dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
317
+ dg = torch.empty_like(g) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
318
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
319
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
320
+
321
+ # Less than 64KB per feature: enqueue fused kernel
322
+ MAX_FUSED_SIZE = 65536 // x.element_size()
323
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
324
+ if N > BLOCK_N:
325
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
326
+ sm_count = get_multiprocessor_count(x.device.index)
327
+ dw = torch.empty((sm_count, N), dtype=torch.float, device=weight.device) if weight is not None else None
328
+ db = torch.empty((sm_count, N), dtype=torch.float, device=bias.device) if bias is not None else None
329
+ rows_per_program = math.ceil(M / sm_count)
330
+ grid = (sm_count,)
331
+ layer_norm_gated_bwd_kernel[grid](
332
+ x,
333
+ g,
334
+ weight,
335
+ bias,
336
+ y,
337
+ dy,
338
+ dx,
339
+ dg,
340
+ dw,
341
+ db,
342
+ dresidual,
343
+ dresidual_in,
344
+ mean,
345
+ rstd,
346
+ M,
347
+ N,
348
+ eps,
349
+ rows_per_program,
350
+ ACTIVATION=activation,
351
+ IS_RMS_NORM=is_rms_norm,
352
+ BLOCK_N=BLOCK_N,
353
+ HAS_DRESIDUAL=dresidual is not None,
354
+ STORE_DRESIDUAL=dresidual_in is not None,
355
+ HAS_WEIGHT=weight is not None,
356
+ HAS_BIAS=bias is not None,
357
+ )
358
+ dw = dw.sum(0).to(weight.dtype) if weight is not None else None
359
+ db = db.sum(0).to(bias.dtype) if bias is not None else None
360
+ # Don't need to compute dresidual_in separately in this case
361
+ if has_residual and dx.dtype == x.dtype:
362
+ dresidual_in = dx
363
+ return (dx, dg, dw, db, dresidual_in) if not recompute_output else (dx, dg, dw, db, dresidual_in, y)
364
+
365
+
366
+ class LayerNormGatedFunction(torch.autograd.Function):
367
+
368
+ @staticmethod
369
+ @input_guard
370
+ def forward(
371
+ ctx,
372
+ x: torch.Tensor,
373
+ g: torch.Tensor,
374
+ weight: torch.Tensor,
375
+ bias: torch.Tensor,
376
+ activation: str,
377
+ residual: Optional[torch.Tensor] = None,
378
+ eps: float = 1e-6,
379
+ prenorm: bool = False,
380
+ residual_in_fp32: bool = False,
381
+ is_rms_norm: bool = False,
382
+ ):
383
+ x_shape_og = x.shape
384
+ g_shape_og = g.shape
385
+ # reshape input data into 2D tensor
386
+ x = x.reshape(-1, x.shape[-1])
387
+ g = g.reshape(-1, g.shape[-1])
388
+ if residual is not None:
389
+ assert residual.shape == x_shape_og
390
+ residual = residual.reshape(-1, residual.shape[-1])
391
+ residual_dtype = (
392
+ residual.dtype
393
+ if residual is not None
394
+ else (torch.float if residual_in_fp32 else None)
395
+ )
396
+ y, mean, rstd, residual_out = layer_norm_gated_fwd(
397
+ x=x,
398
+ g=g,
399
+ weight=weight,
400
+ bias=bias,
401
+ activation=activation,
402
+ eps=eps,
403
+ residual=residual,
404
+ residual_dtype=residual_dtype,
405
+ is_rms_norm=is_rms_norm
406
+ )
407
+ ctx.save_for_backward(residual_out, g, weight, bias, mean, rstd)
408
+ ctx.x_shape_og = x_shape_og
409
+ ctx.g_shape_og = g_shape_og
410
+ ctx.activation = activation
411
+ ctx.eps = eps
412
+ ctx.is_rms_norm = is_rms_norm
413
+ ctx.has_residual = residual is not None
414
+ ctx.prenorm = prenorm
415
+ ctx.x_dtype = x.dtype
416
+ y = y.reshape(x_shape_og)
417
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
418
+
419
+ @staticmethod
420
+ @input_guard
421
+ def backward(ctx, dy, *args):
422
+ x, g, weight, bias, mean, rstd = ctx.saved_tensors
423
+ dy = dy.reshape(-1, dy.shape[-1])
424
+ assert dy.shape == x.shape
425
+ if ctx.prenorm:
426
+ dresidual = args[0]
427
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
428
+ assert dresidual.shape == x.shape
429
+ else:
430
+ dresidual = None
431
+ dx, dg, dw, db, dresidual_in = layer_norm_gated_bwd(
432
+ dy=dy,
433
+ x=x,
434
+ g=g,
435
+ weight=weight,
436
+ bias=bias,
437
+ activation=ctx.activation,
438
+ eps=ctx.eps,
439
+ mean=mean,
440
+ rstd=rstd,
441
+ dresidual=dresidual,
442
+ has_residual=ctx.has_residual,
443
+ is_rms_norm=ctx.is_rms_norm,
444
+ x_dtype=ctx.x_dtype,
445
+ )
446
+ return (
447
+ dx.reshape(ctx.x_shape_og),
448
+ dg.reshape(ctx.g_shape_og),
449
+ dw,
450
+ db,
451
+ None,
452
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
453
+ None,
454
+ None,
455
+ None,
456
+ None,
457
+ )
458
+
459
+
460
+ class LayerNormGatedLinearFunction(torch.autograd.Function):
461
+
462
+ @staticmethod
463
+ @input_guard
464
+ def forward(
465
+ ctx,
466
+ x: torch.Tensor,
467
+ g: torch.Tensor,
468
+ norm_weight: torch.Tensor,
469
+ norm_bias: torch.Tensor,
470
+ linear_weight: torch.Tensor,
471
+ linear_bias: torch.Tensor,
472
+ residual: Optional[torch.Tensor] = None,
473
+ eps: float = 1e-6,
474
+ prenorm: bool = False,
475
+ residual_in_fp32: bool = False,
476
+ is_rms_norm: bool = False,
477
+ ):
478
+ x_shape_og = x.shape
479
+ g_shape_og = g.shape
480
+ # reshape input data into 2D tensor
481
+ x = x.reshape(-1, x.shape[-1])
482
+ g = g.reshape(-1, g.shape[-1])
483
+ if residual is not None:
484
+ assert residual.shape == x_shape_og
485
+ residual = residual.reshape(-1, residual.shape[-1])
486
+ residual_dtype = (
487
+ residual.dtype
488
+ if residual is not None
489
+ else (torch.float if residual_in_fp32 else None)
490
+ )
491
+ y, mean, rstd, residual_out = layer_norm_gated_fwd(
492
+ x=x,
493
+ g=g,
494
+ weight=norm_weight,
495
+ bias=norm_bias,
496
+ eps=eps,
497
+ residual=residual,
498
+ residual_dtype=residual_dtype,
499
+ is_rms_norm=is_rms_norm
500
+ )
501
+ y = y.reshape(x_shape_og)
502
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
503
+ linear_weight = linear_weight.to(dtype)
504
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
505
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
506
+ # We don't store y, will be recomputed in the backward pass to save memory
507
+ ctx.save_for_backward(residual_out, g, norm_weight, norm_bias, linear_weight, mean, rstd)
508
+ ctx.x_shape_og = x_shape_og
509
+ ctx.g_shape_og = g_shape_og
510
+ ctx.eps = eps
511
+ ctx.is_rms_norm = is_rms_norm
512
+ ctx.has_residual = residual is not None
513
+ ctx.prenorm = prenorm
514
+ ctx.x_dtype = x.dtype
515
+ ctx.linear_bias_is_none = linear_bias is None
516
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
517
+
518
+ @staticmethod
519
+ @input_guard
520
+ def backward(ctx, dout, *args):
521
+ x, g, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
522
+ dout = dout.reshape(-1, dout.shape[-1])
523
+ dy = F.linear(dout, linear_weight.t())
524
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
525
+ assert dy.shape == x.shape
526
+ if ctx.prenorm:
527
+ dresidual = args[0]
528
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
529
+ assert dresidual.shape == x.shape
530
+ else:
531
+ dresidual = None
532
+ dx, dg, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_gated_bwd(
533
+ dy=dy,
534
+ x=x,
535
+ g=g,
536
+ weight=norm_weight,
537
+ bias=norm_bias,
538
+ eps=ctx.eps,
539
+ mean=mean,
540
+ rstd=rstd,
541
+ dresidual=dresidual,
542
+ has_residual=ctx.has_residual,
543
+ is_rms_norm=ctx.is_rms_norm,
544
+ x_dtype=ctx.x_dtype,
545
+ recompute_output=True,
546
+ )
547
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
548
+ return (
549
+ dx.reshape(ctx.x_shape_og),
550
+ dg.reshape(ctx.g_shape_og),
551
+ dnorm_weight,
552
+ dnorm_bias,
553
+ dlinear_weight,
554
+ dlinear_bias,
555
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
556
+ None,
557
+ None,
558
+ None,
559
+ None,
560
+ )
561
+
562
+
563
+ def layer_norm_gated(
564
+ x: torch.Tensor,
565
+ g: torch.Tensor,
566
+ weight: torch.Tensor,
567
+ bias: torch.Tensor,
568
+ activation: str = 'swish',
569
+ residual: Optional[torch.Tensor] = None,
570
+ prenorm: bool = False,
571
+ residual_in_fp32: bool = False,
572
+ eps: float = 1e-6
573
+ ):
574
+ return LayerNormGatedFunction.apply(
575
+ x,
576
+ g,
577
+ weight,
578
+ bias,
579
+ activation,
580
+ residual,
581
+ eps,
582
+ prenorm,
583
+ residual_in_fp32,
584
+ False
585
+ )
586
+
587
+
588
+ def rms_norm_gated(
589
+ x: torch.Tensor,
590
+ g: torch.Tensor,
591
+ weight: torch.Tensor,
592
+ bias: torch.Tensor,
593
+ activation: str = 'swish',
594
+ residual: Optional[torch.Tensor] = None,
595
+ prenorm: bool = False,
596
+ residual_in_fp32: bool = False,
597
+ eps: float = 1e-6
598
+ ):
599
+ return LayerNormGatedFunction.apply(
600
+ x,
601
+ g,
602
+ weight,
603
+ bias,
604
+ activation,
605
+ residual,
606
+ eps,
607
+ prenorm,
608
+ residual_in_fp32,
609
+ True
610
+ )
611
+
612
+
613
+ def layer_norm_swish_gate_linear(
614
+ x: torch.Tensor,
615
+ g: torch.Tensor,
616
+ norm_weight: torch.Tensor,
617
+ norm_bias: torch.Tensor,
618
+ linear_weight: torch.Tensor,
619
+ linear_bias: torch.Tensor,
620
+ residual: Optional[torch.Tensor] = None,
621
+ prenorm: bool = False,
622
+ residual_in_fp32: bool = False,
623
+ eps: float = 1e-6
624
+ ):
625
+ return LayerNormGatedLinearFunction.apply(
626
+ x,
627
+ g,
628
+ norm_weight,
629
+ norm_bias,
630
+ linear_weight,
631
+ linear_bias,
632
+ residual,
633
+ eps,
634
+ prenorm,
635
+ residual_in_fp32,
636
+ False
637
+ )
638
+
639
+
640
+ def rms_norm_swish_gate_linear(
641
+ x,
642
+ g: torch.Tensor,
643
+ norm_weight: torch.Tensor,
644
+ norm_bias: torch.Tensor,
645
+ linear_weight: torch.Tensor,
646
+ linear_bias: torch.Tensor,
647
+ residual: Optional[torch.Tensor] = None,
648
+ prenorm: bool = False,
649
+ residual_in_fp32: bool = False,
650
+ eps: float = 1e-6
651
+ ):
652
+ return LayerNormGatedLinearFunction.apply(
653
+ x,
654
+ g,
655
+ norm_weight,
656
+ norm_bias,
657
+ linear_weight,
658
+ linear_bias,
659
+ residual,
660
+ eps,
661
+ prenorm,
662
+ residual_in_fp32,
663
+ True
664
+ )
665
+
666
+
667
+ class FusedLayerNormGated(nn.Module):
668
+
669
+ def __init__(
670
+ self,
671
+ hidden_size: int,
672
+ elementwise_affine: bool = True,
673
+ bias: bool = False,
674
+ activation: str = 'swish',
675
+ eps: float = 1e-5,
676
+ device: Optional[torch.device] = None,
677
+ dtype: Optional[torch.dtype] = None,
678
+ ) -> FusedLayerNormGated:
679
+ factory_kwargs = {"device": device, "dtype": dtype}
680
+ super().__init__()
681
+
682
+ self.hidden_size = hidden_size
683
+ self.elementwise_affine = elementwise_affine
684
+ self.eps = eps
685
+ self.activation = activation
686
+
687
+ if self.activation not in ['swish', 'silu', 'sigmoid']:
688
+ raise ValueError(f"Unsupported activation: {self.activation}")
689
+
690
+ self.register_parameter("weight", None)
691
+ self.register_parameter("bias", None)
692
+ if elementwise_affine:
693
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
694
+ if bias:
695
+ self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
696
+
697
+ self.reset_parameters()
698
+
699
+ def reset_parameters(self):
700
+ if self.elementwise_affine:
701
+ nn.init.ones_(self.weight)
702
+ if self.bias is not None:
703
+ nn.init.zeros_(self.bias)
704
+
705
+ def __repr__(self) -> str:
706
+ s = f"{self.__class__.__name__}({self.hidden_size}"
707
+ if not self.elementwise_affine:
708
+ s += f", elementwise_affine={self.elementwise_affine}"
709
+ s += f", eps={self.eps}"
710
+ s += f", activation={self.activation}"
711
+ s += ")"
712
+ return s
713
+
714
+ def forward(
715
+ self,
716
+ x: torch.Tensor,
717
+ g: torch.Tensor,
718
+ residual: Optional[torch.Tensor] = None,
719
+ prenorm: bool = False,
720
+ residual_in_fp32: bool = False
721
+ ) -> torch.Tensor:
722
+ return layer_norm_gated(
723
+ x,
724
+ g,
725
+ self.weight,
726
+ self.bias,
727
+ self.activation,
728
+ residual=residual,
729
+ eps=self.eps,
730
+ prenorm=prenorm,
731
+ residual_in_fp32=residual_in_fp32
732
+ )
733
+
734
+
735
+ class FusedRMSNormGated(nn.Module):
736
+
737
+ def __init__(
738
+ self,
739
+ hidden_size: int,
740
+ elementwise_affine: bool = True,
741
+ eps: float = 1e-5,
742
+ activation: str = 'swish',
743
+ device: Optional[torch.device] = None,
744
+ dtype: Optional[torch.dtype] = None,
745
+ ) -> FusedRMSNormGated:
746
+ factory_kwargs = {"device": device, "dtype": dtype}
747
+ super().__init__()
748
+
749
+ self.hidden_size = hidden_size
750
+ self.elementwise_affine = elementwise_affine
751
+ self.eps = eps
752
+ self.activation = activation
753
+
754
+ if self.activation not in ['swish', 'silu', 'sigmoid']:
755
+ raise ValueError(f"Unsupported activation: {self.activation}")
756
+
757
+ if elementwise_affine:
758
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
759
+ else:
760
+ self.register_parameter("weight", None)
761
+ self.register_parameter("bias", None)
762
+
763
+ self.reset_parameters()
764
+
765
+ def reset_parameters(self):
766
+ if self.elementwise_affine:
767
+ nn.init.ones_(self.weight)
768
+
769
+ def __repr__(self) -> str:
770
+ s = f"{self.__class__.__name__}({self.hidden_size}"
771
+ if not self.elementwise_affine:
772
+ s += f", elementwise_affine={self.elementwise_affine}"
773
+ s += f", eps={self.eps}"
774
+ s += f", activation={self.activation}"
775
+ s += ")"
776
+ return s
777
+
778
+ def forward(
779
+ self,
780
+ x: torch.Tensor,
781
+ g: torch.Tensor,
782
+ residual: Optional[torch.Tensor] = None,
783
+ prenorm: bool = False,
784
+ residual_in_fp32: bool = False
785
+ ) -> torch.Tensor:
786
+ return rms_norm_gated(
787
+ x,
788
+ g,
789
+ self.weight,
790
+ self.bias,
791
+ self.activation,
792
+ residual=residual,
793
+ eps=self.eps,
794
+ prenorm=prenorm,
795
+ residual_in_fp32=residual_in_fp32
796
+ )
797
+
798
+
799
+ class FusedLayerNormSwishGate(FusedLayerNormGated):
800
+
801
+ def __init__(
802
+ self,
803
+ hidden_size: int,
804
+ elementwise_affine: bool = True,
805
+ bias: bool = False,
806
+ eps: float = 1e-5,
807
+ device: Optional[torch.device] = None,
808
+ dtype: Optional[torch.dtype] = None,
809
+ ) -> FusedLayerNormSwishGate:
810
+ super().__init__(
811
+ hidden_size=hidden_size,
812
+ elementwise_affine=elementwise_affine,
813
+ bias=bias,
814
+ eps=eps,
815
+ device=device,
816
+ dtype=dtype
817
+ )
818
+
819
+
820
+ class FusedRMSNormSwishGate(FusedRMSNormGated):
821
+
822
+ def __init__(
823
+ self,
824
+ hidden_size: int,
825
+ elementwise_affine: bool = True,
826
+ eps: float = 1e-5,
827
+ device: Optional[torch.device] = None,
828
+ dtype: Optional[torch.dtype] = None,
829
+ ) -> FusedRMSNormSwishGate:
830
+ super().__init__(
831
+ hidden_size=hidden_size,
832
+ elementwise_affine=elementwise_affine,
833
+ eps=eps,
834
+ device=device,
835
+ dtype=dtype
836
+ )
837
+
838
+
839
+ class FusedLayerNormGatedLinear(nn.Module):
840
+
841
+ def __init__(
842
+ self,
843
+ hidden_size: int,
844
+ elementwise_affine: bool = True,
845
+ eps: float = 1e-5,
846
+ device: Optional[torch.device] = None,
847
+ dtype: Optional[torch.dtype] = None,
848
+ ) -> FusedLayerNormGatedLinear:
849
+ factory_kwargs = {"device": device, "dtype": dtype}
850
+ super().__init__()
851
+
852
+ self.hidden_size = hidden_size
853
+ self.elementwise_affine = elementwise_affine
854
+ self.eps = eps
855
+
856
+ if elementwise_affine:
857
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
858
+ else:
859
+ self.register_parameter("weight", None)
860
+ self.register_parameter("bias", None)
861
+
862
+ self.reset_parameters()
863
+
864
+ def reset_parameters(self):
865
+ if self.elementwise_affine:
866
+ nn.init.ones_(self.weight)
867
+
868
+ def __repr__(self) -> str:
869
+ s = f"{self.__class__.__name__}({self.hidden_size}"
870
+ if not self.elementwise_affine:
871
+ s += f", elementwise_affine={self.elementwise_affine}"
872
+ s += f", eps={self.eps}"
873
+ s += ")"
874
+ return s
875
+
876
+ def forward(
877
+ self,
878
+ x: torch.Tensor,
879
+ g: torch.Tensor,
880
+ weight: Optional[torch.Tensor] = None,
881
+ bias: Optional[torch.Tensor] = None,
882
+ residual: Optional[torch.Tensor] = None,
883
+ prenorm: bool = False,
884
+ residual_in_fp32: bool = False
885
+ ) -> torch.Tensor:
886
+ return layer_norm_swish_gate_linear(
887
+ x,
888
+ g,
889
+ self.weight,
890
+ self.bias,
891
+ weight,
892
+ bias,
893
+ residual=residual,
894
+ eps=self.eps,
895
+ prenorm=prenorm,
896
+ residual_in_fp32=residual_in_fp32
897
+ )
898
+
899
+
900
+ class FusedLayerNormSwishGateLinear(FusedLayerNormGatedLinear):
901
+
902
+ def __init__(
903
+ self,
904
+ hidden_size: int,
905
+ elementwise_affine: bool = True,
906
+ eps: float = 1e-5,
907
+ device: Optional[torch.device] = None,
908
+ dtype: Optional[torch.dtype] = None,
909
+ ) -> FusedLayerNormSwishGateLinear:
910
+ super().__init__(
911
+ hidden_size=hidden_size,
912
+ elementwise_affine=elementwise_affine,
913
+ eps=eps,
914
+ device=device,
915
+ dtype=dtype
916
+ )
917
+
918
+
919
+ class FusedRMSNormGatedLinear(nn.Module):
920
+
921
+ def __init__(
922
+ self,
923
+ hidden_size,
924
+ elementwise_affine: bool = True,
925
+ eps: float = 1e-5,
926
+ device: Optional[torch.device] = None,
927
+ dtype: Optional[torch.dtype] = None,
928
+ ) -> FusedRMSNormGatedLinear:
929
+ factory_kwargs = {"device": device, "dtype": dtype}
930
+ super().__init__()
931
+
932
+ self.hidden_size = hidden_size
933
+ self.elementwise_affine = elementwise_affine
934
+ self.eps = eps
935
+
936
+ self.register_parameter("weight", None)
937
+ self.register_parameter("bias", None)
938
+ if elementwise_affine:
939
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
940
+
941
+ self.reset_parameters()
942
+
943
+ def reset_parameters(self):
944
+ if self.elementwise_affine:
945
+ nn.init.ones_(self.weight)
946
+
947
+ def __repr__(self) -> str:
948
+ s = f"{self.__class__.__name__}({self.hidden_size}"
949
+ if not self.elementwise_affine:
950
+ s += f", elementwise_affine={self.elementwise_affine}"
951
+ s += f", eps={self.eps}"
952
+ s += ")"
953
+ return s
954
+
955
+ def forward(
956
+ self,
957
+ x: torch.Tensor,
958
+ g: torch.Tensor,
959
+ weight: Optional[torch.Tensor] = None,
960
+ bias: Optional[torch.Tensor] = None,
961
+ residual: Optional[torch.Tensor] = None,
962
+ prenorm: bool = False,
963
+ residual_in_fp32: bool = False
964
+ ) -> torch.Tensor:
965
+ return rms_norm_swish_gate_linear(
966
+ x,
967
+ g,
968
+ self.weight,
969
+ self.bias,
970
+ weight,
971
+ bias,
972
+ residual=residual,
973
+ eps=self.eps,
974
+ prenorm=prenorm,
975
+ residual_in_fp32=residual_in_fp32
976
+ )
977
+
978
+
979
+ class FusedRMSNormSwishGateLinear(FusedRMSNormGatedLinear):
980
+
981
+ def __init__(
982
+ self,
983
+ hidden_size: int,
984
+ elementwise_affine: bool = True,
985
+ eps: float = 1e-5,
986
+ device: Optional[torch.device] = None,
987
+ dtype: Optional[torch.dtype] = None,
988
+ ) -> FusedRMSNormSwishGateLinear:
989
+ super().__init__(
990
+ hidden_size=hidden_size,
991
+ elementwise_affine=elementwise_affine,
992
+ eps=eps,
993
+ device=device,
994
+ dtype=dtype
995
+ )
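A minimal usage sketch of the gated RMSNorm modules defined above (the shapes, dtype, and CUDA device below are illustrative assumptions, not part of this commit; the fused kernel requires Triton on a GPU):

import torch
from fla.modules.fused_norm_gate import FusedRMSNormSwishGate

# Normalize x and gate it with swish(g), as fla layers do before the output projection.
norm = FusedRMSNormSwishGate(hidden_size=2048).cuda()
x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
g = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
y = norm(x, g)  # rms_norm(x) * swish(g), computed in one fused kernel
assert y.shape == x.shape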
fla/modules/grpo.py ADDED
@@ -0,0 +1,396 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
4
+ """
5
+ # Get the per-token log probabilities for the completions for the model and the reference model
6
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
7
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
8
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
9
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
10
+
11
+ input_ids = input_ids[:, -logits_to_keep:]
12
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
13
+ # See https://github.com/huggingface/trl/issues/2770
14
+ logits = logits[:, -logits_to_keep:]
15
+ return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
16
+
17
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
18
+ if return_outputs:
19
+ raise ValueError("The GRPOTrainer does not support returning outputs")
20
+ # Compute the per-token log probabilities for the model
21
+
22
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
23
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
24
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
25
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
26
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
27
+
28
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
29
+
30
+ # Compute the KL divergence between the model and the reference model
31
+ ref_per_token_logps = inputs["ref_per_token_logps"]
32
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
33
+
34
+ # x - x.detach() allows for preserving gradients from x
35
+ advantages = inputs["advantages"]
36
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
37
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
38
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
39
+
40
+ # Log the metrics
41
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
42
+ self._metrics["completion_length"].append(completion_length)
43
+
44
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
45
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
46
+
47
+ return loss
48
+ """
49
+
50
+
51
+ import torch
52
+ import triton
53
+ import triton.language as tl
54
+
55
+ from fla.ops.utils.op import exp, log
56
+ from fla.utils import input_guard
57
+
58
+
59
+ @triton.autotune(
60
+ [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
61
+ for BLOCK_SIZE in [1024, 2048, 4096, 8192]
62
+ for NUM_WARPS in [8, 16, 32]
63
+ for NUM_STAGES in [1, 2, 4]
64
+ ], key=['B', 'N']
65
+ )
66
+ @triton.jit
67
+ def grpo_fwd_kernel(
68
+ logits_ptr,
69
+ ref_logp_ptr,
70
+ input_ids_ptr,
71
+ advantages_ptr,
72
+ completion_mask_ptr,
73
+ loss_ptr,
74
+ lse_ptr,
75
+ beta,
76
+ save_kl: tl.constexpr,
77
+ B,
78
+ M,
79
+ N,
80
+ L,
81
+ start_idx,
82
+ BLOCK_SIZE: tl.constexpr
83
+ ):
84
+ row_idx = tl.program_id(0)
85
+
86
+ off_b = row_idx // L
87
+ N = tl.cast(N, tl.int64)
88
+
89
+ loss_ptr += row_idx
90
+
91
+ completion_mask_ptr += row_idx
92
+ not_skip = tl.load(completion_mask_ptr).to(tl.int1)
93
+ if not_skip == 1:
94
+ ref_logp_ptr += row_idx
95
+ lse_ptr += row_idx
96
+ advantages_ptr += off_b
97
+ logits_ptr += N * (row_idx + off_b)
98
+ input_ids_ptr += row_idx + (off_b+1) * start_idx
99
+ base_cols = tl.arange(0, BLOCK_SIZE)
100
+
101
+ m_i = -float("inf")
102
+ l_i = 0.0
103
+ for start_n in tl.range(0, N, BLOCK_SIZE):
104
+ cols = start_n + base_cols
105
+ mask = cols < N
106
+ logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
107
+ m_ij = tl.max(logits)
108
+ new_m_i = tl.maximum(m_i, m_ij)
109
+ l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i))
110
+ m_i = new_m_i
111
+ lse = log(l_i) + m_i
112
+
113
+ idx = tl.load(input_ids_ptr)
114
+ x = tl.load(logits_ptr+idx).to(tl.float32)
115
+ advantage = tl.load(advantages_ptr).to(tl.float32)
116
+ ref_logp = tl.load(ref_logp_ptr)
117
+ logp = x - lse
118
+ diff = ref_logp - logp
119
+ kl = exp(diff) - diff - 1
120
+ loss = kl * beta - advantage
121
+
122
+ tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty))
123
+ tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty))
124
+ if save_kl:
125
+ tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty))
126
+ else:
127
+ # store 0
128
+ tl.store(loss_ptr, 0.0)
129
+ if save_kl:
130
+ tl.store(loss_ptr+M, 0.0)
131
+
132
+
133
+ @triton.autotune(
134
+ [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
135
+ for BLOCK_SIZE in [1024, 2048, 4096, 8192]
136
+ for NUM_WARPS in [8, 16, 32]
137
+ for NUM_STAGES in [1, 2, 4]
138
+ ], key=['B', 'N']
139
+ )
140
+ @triton.jit
141
+ def grpo_bwd_kernel(
142
+ dloss_ptr,
143
+ dlogits_ptr,
144
+ logits_ptr,
145
+ ref_logp_ptr,
146
+ input_ids_ptr,
147
+ advantages_ptr,
148
+ completion_mask_ptr,
149
+ lse_ptr,
150
+ beta,
151
+ B,
152
+ N,
153
+ L,
154
+ start_idx,
155
+ BLOCK_SIZE: tl.constexpr
156
+ ):
157
+
158
+ row_idx = tl.program_id(0) # B*L
159
+ off_b = row_idx // L
160
+
161
+ N = tl.cast(N, tl.int64)
162
+
163
+ dlogits_ptr += N * (row_idx + off_b)
164
+ base_cols = tl.arange(0, BLOCK_SIZE)
165
+ completion_mask_ptr += row_idx
166
+ not_skip = tl.load(completion_mask_ptr).to(tl.int1)
167
+
168
+ if not_skip == 1:
169
+ lse_ptr += row_idx
170
+ dloss_ptr += row_idx
171
+ advantages_ptr += off_b
172
+ ref_logp_ptr += row_idx
173
+ logits_ptr += N * (row_idx + off_b)
174
+ input_ids_ptr += row_idx + (off_b+1) * start_idx
175
+ dloss = tl.load(dloss_ptr).to(tl.float32)
176
+ lse = tl.load(lse_ptr).to(tl.float32)
177
+ idx = tl.load(input_ids_ptr)
178
+ x = tl.load(logits_ptr+idx).to(tl.float32)
179
+ advantage = tl.load(advantages_ptr).to(tl.float32)
180
+ ref_logp = tl.load(ref_logp_ptr)
181
+ logp = x - lse
182
+
183
+ dlogp = (beta * (-1.0 * exp(ref_logp - logp) + 1)
184
+ - advantage) * dloss
185
+
186
+ for start_n in tl.range(0, N, BLOCK_SIZE):
187
+ cols = start_n + base_cols
188
+ mask = cols < N
189
+ logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
190
+ probs = exp(logits - lse)
191
+ dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp
192
+
193
+ tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
194
+ else:
195
+ dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
196
+ for start_n in tl.range(0, N, BLOCK_SIZE):
197
+ cols = start_n + base_cols
198
+ mask = cols < N
199
+
200
+ tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
201
+
202
+
203
+ class GrpoLoss(torch.autograd.Function):
204
+
205
+ @input_guard
206
+ @staticmethod
207
+ def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl):
208
+ ctx.input_shape = logits.shape
209
+ B, L_ADD_1, N = ctx.input_shape
210
+ L = L_ADD_1 - 1
211
+ M = B * L
212
+ input_ids_start_index = input_ids.size(1) - L
213
+
214
+ if not save_kl:
215
+ loss = torch.empty(B, L, device=logits.device, dtype=torch.float32)
216
+ else:
217
+ loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32)
218
+
219
+ lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)
220
+
221
+ if completion_mask is None:
222
+ completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)
223
+ else:
224
+ loss[:B].masked_fill_(completion_mask.logical_not(), 0.0)
225
+
226
+ grpo_fwd_kernel[(M,)](
227
+ logits_ptr=logits,
228
+ ref_logp_ptr=ref_logp,
229
+ input_ids_ptr=input_ids,
230
+ advantages_ptr=advantages,
231
+ completion_mask_ptr=completion_mask,
232
+ loss_ptr=loss,
233
+ lse_ptr=lse,
234
+ beta=beta,
235
+ save_kl=save_kl,
236
+ B=B, M=M, N=N, L=L,
237
+ start_idx=input_ids_start_index,
238
+ )
239
+ ctx.beta = beta
240
+ ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
241
+ ctx.ref_logp = ref_logp
242
+ return loss
243
+
244
+ @input_guard
245
+ @staticmethod
246
+ def backward(ctx, dloss):
247
+ # The grad of logits comes from two parts, the reward part and the kl part
248
+ lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
249
+ B, L_ADD_1, N = ctx.input_shape
250
+ L = L_ADD_1 - 1
251
+ M = B * L
252
+
253
+ input_ids_start_index = input_ids.size(1) - L
254
+
255
+ dlogits = torch.empty_like(logits) # B, L_ADD_1, N
256
+
257
+ grpo_bwd_kernel[(M,)](
258
+ dloss_ptr=dloss,
259
+ dlogits_ptr=dlogits,
260
+ logits_ptr=logits,
261
+ ref_logp_ptr=ctx.ref_logp,
262
+ input_ids_ptr=input_ids,
263
+ advantages_ptr=advantages,
264
+ completion_mask_ptr=completion_mask,
265
+ lse_ptr=lse,
266
+ beta=ctx.beta,
267
+ B=B, N=N, L=L,
268
+ start_idx=input_ids_start_index,
269
+ )
270
+ # The last token in the completion is not used in the loss computation
271
+ # and therefore its gradient should be set to 0
272
+ dlogits[:, -1, :].fill_(0.0)
273
+ return dlogits.view(*ctx.input_shape), None, None, None, None, None, None
274
+
275
+
276
+ def fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False) -> torch.Tensor:
277
+ '''
278
+ Compute the GRPO loss; saves memory (no additional usage) and is fast (about 6x speedup on an A800).
279
+
280
+ Args:
281
+ logits: Tensor, [B, L+1, vocab_size], the raw model output; do not pre-slice it to logits[:, :-1]
282
+ ref_logp: Tensor, [B, L], per-token log probabilities from the reference model (not raw ref_logits[:, :-1])
283
+ input_ids: Tensor, [B, K+L], the prompt_completion_ids, containing both the prompt ids and the completion ids
284
+ advantages: Tensor, [B], the advantages of each prompt
285
+ beta: float, the weight of the KL loss
286
+ completion_mask: Tensor, loss mask
287
+ save_kl: bool, if True, the per-token KL is returned as well
288
+
289
+ Returns:
290
+ loss: Tensor, [B, L], the GRPO loss, containing both the advantage part and the KL part
291
+
292
+ NOTE: logits (and ref_logits) are computed as follows
293
+ logits_to_keep = completion_ids.size(1)
294
+
295
+ def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
296
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
297
+ logits = model(
298
+ input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
299
+ ).logits
300
+ return logits
301
+
302
+ logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
303
+ '''
304
+ out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)
305
+ if not save_kl:
306
+ return out
307
+ else:
308
+ return out.chunk(2, dim=0)
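A hypothetical call sketch for fused_grpo_loss (the sizes below are illustrative assumptions; the Triton kernel requires a CUDA device):

import torch
from fla.modules.grpo import fused_grpo_loss

B, L, K, V = 2, 16, 8, 32000                       # batch, completion length, prompt length, vocab
logits = torch.randn(B, L + 1, V, device='cuda')   # current policy logits over L+1 positions
ref_logp = torch.randn(B, L, device='cuda')        # reference per-token log-probs
input_ids = torch.randint(0, V, (B, K + L), device='cuda')
advantages = torch.randn(B, device='cuda')
loss = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1)  # -> [B, L]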
309
+
310
+
311
+ def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
312
+ def get_log_probs(logits, input_ids):
313
+ per_token_logps = []
314
+ for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
315
+ log_probs = logits_row.log_softmax(dim=-1)
316
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
317
+ per_token_logps.append(token_log_prob)
318
+ return torch.stack(per_token_logps)
319
+
320
+ logits = logits[:, :-1]
321
+ per_token_logps = get_log_probs(logits, input_ids)
322
+ ref_per_token_logps = ref_logp
323
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
324
+
325
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
326
+ per_token_loss = -(per_token_loss - beta * per_token_kl)
327
+ if completion_mask is not None:
328
+ per_token_loss *= completion_mask
329
+ if save_kl:
330
+ per_token_kl *= completion_mask
331
+ return per_token_loss if not save_kl else (per_token_loss, per_token_kl)
332
+
333
+
334
+ @torch.compile(fullgraph=True)
335
+ def grpo_loss_with_old_logps(
336
+ logps: torch.Tensor,
337
+ ref_logps: torch.Tensor,
338
+ old_logps: torch.Tensor,
339
+ pad_mask: torch.Tensor,
340
+ logits_to_keep: int,
341
+ rewards: torch.Tensor,
342
+ beta: float = 0.2,
343
+ epsilon: float = 0.2
344
+ ):
345
+ """
346
+ Compute the GRPO (Group Relative Policy Optimization) loss.
347
+
348
+ Args:
349
+ logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy.
350
+ ref_logps (torch.Tensor):[Batch, Token_length] Log probabilities of the reference policy.
351
+ old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy.
352
+ completion_ids (torch.Tensor): [Batch, Token_length] Completion token IDs (bool).
353
+ pad_token_id: Pad token ID.
354
+ logits_to_keep (int): Number of logits to keep for masking.
355
+ rewards (torch.Tensor): [Batch] Rewards for each generation.
356
+ beta (float) = 0.2: A hyperparameter for weighting the KL divergence term.
357
+ epsilon (float) = 0.2: An float hyperparameter for clipping the importance weights.
358
+
359
+ Returns:
360
+ torch.Tensor: The computed GRPO loss.
361
+ """
362
+ B = logps.shape[0]
363
+ assert B > 1, "Batch * Num generations should be greater than 1"
364
+
365
+ rewards_shaped = rewards.view(-1, B) # B,num_generations
366
+ advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \
367
+ (rewards_shaped.std(dim=1, keepdim=True) + 1e-8)
368
+ advantages = advantages.view(-1) # B*num_generations
369
+ # Calculate the per-token KL divergence
370
+ per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
371
+
372
+ # Calculate the ratio of probabilities (importance weights)
373
+ # Importance weights are calculated as exp(log_pi_theta - log_pi_theta_old)
374
+ importance_weights = torch.exp(logps - old_logps)
375
+
376
+ # Clip the importance weights to the range [1 - epsilon, 1 + epsilon]
377
+ importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon)
378
+
379
+ # Create a completion mask. It checks which positions are valid based on logits_to_keep
380
+ completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0
381
+
382
+ # Combine the completion mask and padding mask
383
+ completion_mask = completion_mask & pad_mask # Ensure matching shape
384
+
385
+ # Add an extra dimension to advantages to match the shape for element-wise multiplication
386
+ advantages = advantages.unsqueeze(1)
387
+
388
+ # Calculate the per-token loss. It takes the minimum of the unclipped and clipped importance weights
389
+ # and subtracts the KL divergence term weighted by beta, then multiplies by the completion mask
390
+ token_loss = -(torch.min(advantages * importance_weights, advantages *
391
+ importance_weights_clipped) - beta * per_token_kl) * completion_mask
392
+
393
+ # Calculate the final loss by summing the token losses and normalizing by the number of valid tokens
394
+ loss = -token_loss.sum() / completion_mask.sum()
395
+
396
+ return loss
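A short usage sketch of the clipped variant above (synthetic CPU inputs; the shapes follow the docstring and are assumptions only):

import torch
from fla.modules.grpo import grpo_loss_with_old_logps

B, T = 4, 16
logps = torch.randn(B, T, requires_grad=True)      # current policy log-probs
ref_logps = torch.randn(B, T)                       # reference policy log-probs
old_logps = torch.randn(B, T)                       # old policy log-probs
pad_mask = torch.ones(B, T, dtype=torch.bool)       # all tokens valid in this toy case
rewards = torch.randn(B)
loss = grpo_loss_with_old_logps(logps, ref_logps, old_logps, pad_mask, T, rewards, beta=0.2, epsilon=0.2)
loss.backward()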
fla/modules/l2norm.py ADDED
@@ -0,0 +1,176 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import input_guard
11
+
12
+
13
+ @triton.autotune(
14
+ configs=[
15
+ triton.Config({}, num_warps=num_warps)
16
+ for num_warps in [1, 2, 4, 8, 16, 32]
17
+ ],
18
+ key=['N']
19
+ )
20
+ @triton.jit
21
+ def l2norm_fwd_kernel(
22
+ X,
23
+ Y,
24
+ N,
25
+ eps,
26
+ BLOCK_N: tl.constexpr,
27
+ ):
28
+ i_m = tl.program_id(0)
29
+ X += i_m * N
30
+ Y += i_m * N
31
+ # Compute mean and variance
32
+ cols = tl.arange(0, BLOCK_N)
33
+ mask = cols < N
34
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
35
+ xbar = tl.where(mask, x, 0.0)
36
+ var = tl.sum(xbar * xbar, axis=0)
37
+ rstd = 1 / tl.sqrt(var + eps)
38
+ # tl.store(Rstd + i_m, rstd)
39
+ # Normalize and apply linear transformation
40
+ y = x * rstd
41
+ # Write output
42
+ tl.store(Y + cols, y, mask=mask)
43
+
44
+
45
+ @triton.autotune(
46
+ configs=[
47
+ triton.Config({}, num_warps=num_warps)
48
+ for num_warps in [1, 2, 4, 8, 16, 32]
49
+ ],
50
+ key=['N']
51
+ )
52
+ @triton.jit
53
+ def l2norm_bwd_kernel(
54
+ X,
55
+ DY,
56
+ DX,
57
+ N,
58
+ eps,
59
+ BLOCK_N: tl.constexpr,
60
+ ):
61
+ i_m = tl.program_id(0)
62
+ X += i_m * N
63
+ DX += i_m * N
64
+ DY += i_m * N
65
+
66
+ # Y += i_m * stride_y_row
67
+ cols = tl.arange(0, BLOCK_N)
68
+ mask = cols < N
69
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
70
+ x = tl.where(mask, x, 0.0)
71
+ var = tl.sum(x * x)
72
+ rstd = 1 / tl.sqrt(var + eps)
73
+ # tl.store(Rstd + i_m, rstd)
74
+ # Normalize and apply linear transformation
75
+ # y = x * rstd
76
+ dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
77
+ dy = tl.where(mask, dy, 0.0)
78
+ dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
79
+ tl.store(DX + cols, dx, mask=mask)
80
+
81
+
82
+ def l2norm_fwd(
83
+ x: torch.Tensor,
84
+ eps: float = 1e-6,
85
+ output_dtype: Optional[torch.dtype] = None
86
+ ):
87
+ x_shape_og = x.shape
88
+ x = x.reshape(-1, x.shape[-1])
89
+ # allocate output
90
+ if output_dtype is None:
91
+ y = torch.empty_like(x)
92
+ else:
93
+ y = torch.empty_like(x, dtype=output_dtype)
94
+ assert y.stride(-1) == 1
95
+ N = x.shape[-1]
96
+ M = x.shape[0]
97
+ # rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
98
+ # Less than 64KB per feature: enqueue fused kernel
99
+ MAX_FUSED_SIZE = 65536 // x.element_size()
100
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
101
+ if N > BLOCK_N:
102
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
103
+ # heuristics for number of warps
104
+ l2norm_fwd_kernel[(M,)](
105
+ x,
106
+ y,
107
+ N,
108
+ eps,
109
+ BLOCK_N,
110
+ )
111
+ return y.reshape(x_shape_og)
112
+
113
+
114
+ def l2norm_bwd(
115
+ x: torch.Tensor,
116
+ dy: torch.Tensor,
117
+ eps: float = 1e-5
118
+ ):
119
+ x_shape_og = x.shape
120
+ x = x.reshape(-1, dy.shape[-1])
121
+ dy = dy.reshape(-1, dy.shape[-1])
122
+ if dy.stride(-1) != 1:
123
+ dy = dy.contiguous()
124
+ assert dy.shape == x.shape
125
+ # allocate output
126
+ dx = torch.empty_like(x)
127
+ M = x.shape[0]
128
+ N = x.shape[-1]
129
+ # rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
130
+ # Less than 64KB per feature: enqueue fused kernel
131
+ MAX_FUSED_SIZE = 65536 // x.element_size()
132
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
133
+ if N > BLOCK_N:
134
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
135
+ # heuristics for number of warps
136
+ l2norm_bwd_kernel[(M,)](
137
+ x,
138
+ dy,
139
+ dx,
140
+ N,
141
+ eps,
142
+ BLOCK_N,
143
+ )
144
+ return dx.reshape(x_shape_og)
145
+
146
+
147
+ class L2NormFunction(torch.autograd.Function):
148
+
149
+ @staticmethod
150
+ @input_guard
151
+ def forward(
152
+ ctx,
153
+ x,
154
+ eps=1e-6,
155
+ output_dtype=None
156
+ ):
157
+ y = l2norm_fwd(x, eps, output_dtype)
158
+ ctx.eps = eps
159
+ ctx.x_dtype = x.dtype
160
+ ctx.save_for_backward(x)
161
+ return y
162
+
163
+ @staticmethod
164
+ @input_guard
165
+ def backward(ctx, dy):
166
+ x, = ctx.saved_tensors
167
+ dx = l2norm_bwd(x, dy, ctx.eps)
168
+ return dx, None, None
169
+
170
+
171
+ def l2_norm(
172
+ x: torch.Tensor,
173
+ eps: float = 1e-6,
174
+ output_dtype: Optional[torch.dtype] = None
175
+ ) -> torch.Tensor:
176
+ return L2NormFunction.apply(x, eps, output_dtype)
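A quick parity sketch for l2_norm against a plain PyTorch reference over the last dimension (assumes a GPU with Triton; tolerances are illustrative):

import torch
from fla.modules.l2norm import l2_norm

x = torch.randn(4, 64, device='cuda')
y = l2_norm(x, eps=1e-6)
# The kernel computes x * rsqrt(sum(x^2) + eps) row-wise.
ref = x * torch.rsqrt(x.pow(2).sum(-1, keepdim=True) + 1e-6)
torch.testing.assert_close(y, ref, rtol=1e-3, atol=1e-3)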
fla/modules/layernorm_gated.py ADDED
@@ -0,0 +1,528 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
+
7
+ import math
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import triton
14
+ import triton.language as tl
15
+ from einops import rearrange
16
+
17
+ from fla.utils import get_multiprocessor_count, input_guard
18
+
19
+
20
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
21
+ dtype = x.dtype
22
+ weight = weight.float()
23
+ bias = bias.float() if bias is not None else None
24
+ if upcast:
25
+ x = x.float()
26
+ z = z.float() if z is not None else z
27
+ if z is not None and not norm_before_gate:
28
+ x = x * F.silu(z)
29
+ if group_size is None:
30
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
31
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
32
+ else:
33
+ x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
34
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
35
+ out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
36
+ if bias is not None:
37
+ out = out + bias
38
+ if z is not None and norm_before_gate:
39
+ out *= F.silu(z)
40
+ return out.to(dtype)
41
+
42
+
43
+ @triton.heuristics({
44
+ "HAS_BIAS": lambda args: args["B"] is not None,
45
+ "HAS_Z": lambda args: args["Z"] is not None,
46
+ })
47
+ @triton.jit
48
+ def layer_norm_fwd_kernel(
49
+ X, # pointer to the input
50
+ Y, # pointer to the output
51
+ W, # pointer to the weights
52
+ B, # pointer to the biases
53
+ Z, # pointer to the other branch
54
+ Mean, # pointer to the mean
55
+ Rstd, # pointer to the 1/std
56
+ stride_x_row, # how much to increase the pointer when moving by 1 row
57
+ stride_y_row,
58
+ stride_z_row,
59
+ M, # number of rows in X
60
+ N, # number of columns in X
61
+ eps, # epsilon to avoid division by zero
62
+ BLOCK_N: tl.constexpr,
63
+ HAS_BIAS: tl.constexpr,
64
+ HAS_Z: tl.constexpr,
65
+ NORM_BEFORE_GATE: tl.constexpr,
66
+ IS_RMS_NORM: tl.constexpr,
67
+ ):
68
+ # Map the program id to the row of X and Y it should compute.
69
+ row = tl.program_id(0)
70
+ group = tl.program_id(1)
71
+ X += row * stride_x_row + group * N
72
+ Y += row * stride_y_row + group * N
73
+ if HAS_Z:
74
+ Z += row * stride_z_row + group * N
75
+ if not IS_RMS_NORM:
76
+ Mean += group * M
77
+ Rstd += group * M
78
+ W += group * N
79
+ if HAS_BIAS:
80
+ B += group * N
81
+ # Compute mean and variance
82
+ cols = tl.arange(0, BLOCK_N)
83
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
84
+ if HAS_Z and not NORM_BEFORE_GATE:
85
+ z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
86
+ x *= z * tl.sigmoid(z)
87
+ if not IS_RMS_NORM:
88
+ mean = tl.sum(x, axis=0) / N
89
+ tl.store(Mean + row, mean)
90
+ xbar = tl.where(cols < N, x - mean, 0.)
91
+ var = tl.sum(xbar * xbar, axis=0) / N
92
+ else:
93
+ xbar = tl.where(cols < N, x, 0.)
94
+ var = tl.sum(xbar * xbar, axis=0) / N
95
+ rstd = 1 / tl.sqrt(var + eps)
96
+ tl.store(Rstd + row, rstd)
97
+ # Normalize and apply linear transformation
98
+ mask = cols < N
99
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
100
+ if HAS_BIAS:
101
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
102
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
103
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
104
+ if HAS_Z and NORM_BEFORE_GATE:
105
+ z = tl.load(Z + cols, mask=mask).to(tl.float32)
106
+ y *= z * tl.sigmoid(z)
107
+ # Write output
108
+ tl.store(Y + cols, y, mask=mask)
109
+
110
+
111
+ def layer_norm_fwd(
112
+ x: torch.Tensor,
113
+ weight: torch.Tensor,
114
+ bias: torch.Tensor,
115
+ eps: float,
116
+ z: torch.Tensor = None,
117
+ out: torch.Tensor = None,
118
+ group_size: int = None,
119
+ norm_before_gate: bool = True,
120
+ is_rms_norm: bool = False,
121
+ ):
122
+ M, N = x.shape
123
+ if group_size is None:
124
+ group_size = N
125
+ assert N % group_size == 0
126
+ ngroups = N // group_size
127
+ assert x.stride(-1) == 1
128
+ if z is not None:
129
+ assert z.stride(-1) == 1
130
+ assert z.shape == (M, N)
131
+ assert weight.shape == (N,)
132
+ assert weight.stride(-1) == 1
133
+ if bias is not None:
134
+ assert bias.stride(-1) == 1
135
+ assert bias.shape == (N,)
136
+ # allocate output
137
+ if out is not None:
138
+ assert out.shape == x.shape
139
+ else:
140
+ out = torch.empty_like(x)
141
+ assert out.stride(-1) == 1
142
+ mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
143
+ rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
144
+ # Less than 64KB per feature: enqueue fused kernel
145
+ MAX_FUSED_SIZE = 65536 // x.element_size()
146
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
147
+ if group_size > BLOCK_N:
148
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
149
+ # heuristics for number of warps
150
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
151
+ grid = (M, ngroups)
152
+ layer_norm_fwd_kernel[grid](
153
+ x,
154
+ out,
155
+ weight,
156
+ bias,
157
+ z,
158
+ mean,
159
+ rstd,
160
+ x.stride(0),
161
+ out.stride(0),
162
+ z.stride(0) if z is not None else 0,
163
+ M,
164
+ group_size,
165
+ eps,
166
+ BLOCK_N=BLOCK_N,
167
+ NORM_BEFORE_GATE=norm_before_gate,
168
+ IS_RMS_NORM=is_rms_norm,
169
+ num_warps=num_warps
170
+ )
171
+ return out, mean, rstd
172
+
173
+
174
+ @triton.heuristics({
175
+ "HAS_BIAS": lambda args: args["B"] is not None,
176
+ "HAS_Z": lambda args: args["Z"] is not None,
177
+ "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None,
178
+ })
179
+ @triton.jit
180
+ def layer_norm_bwd_kernel(
181
+ X, # pointer to the input
182
+ W, # pointer to the weights
183
+ B, # pointer to the biases
184
+ Z, # pointer to the other branch
185
+ Y, # pointer to the output to be recomputed
186
+ DY, # pointer to the output gradient
187
+ DX, # pointer to the input gradient
188
+ DW, # pointer to the partial sum of weights gradient
189
+ DB, # pointer to the partial sum of biases gradient
190
+ DZ, # pointer to the other branch
191
+ Mean, # pointer to the mean
192
+ Rstd, # pointer to the 1/std
193
+ stride_x_row, # how much to increase the pointer when moving by 1 row
194
+ stride_z_row,
195
+ stride_y_row,
196
+ stride_dy_row,
197
+ stride_dx_row,
198
+ stride_dz_row,
199
+ stride_dw_row,
200
+ stride_db_row,
201
+ M, # number of rows in X
202
+ N, # number of columns in X
203
+ eps, # epsilon to avoid division by zero
204
+ rows_per_program,
205
+ NORM_BEFORE_GATE: tl.constexpr,
206
+ IS_RMS_NORM: tl.constexpr,
207
+ HAS_BIAS: tl.constexpr,
208
+ HAS_Z: tl.constexpr,
209
+ RECOMPUTE_OUTPUT: tl.constexpr,
210
+ BLOCK_N: tl.constexpr,
211
+ ):
212
+ # Map the program id to the elements of X, DX, and DY it should compute.
213
+ row_block_id = tl.program_id(0)
214
+ group = tl.program_id(1)
215
+ row_start = row_block_id * rows_per_program
216
+ cols = tl.arange(0, BLOCK_N)
217
+ mask = cols < N
218
+ X += row_start * stride_x_row + group * N
219
+ if HAS_Z:
220
+ Z += row_start * stride_z_row + group * N
221
+ DZ += row_start * stride_dz_row + group * N
222
+ DY += row_start * stride_dy_row + group * N
223
+ DX += row_start * stride_dx_row + group * N
224
+ if RECOMPUTE_OUTPUT:
225
+ Y += row_start * stride_y_row + group * N
226
+ if not IS_RMS_NORM:
227
+ Mean += group * M
228
+ Rstd += group * M
229
+ W += group * N
230
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
231
+ if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
232
+ B += group * N
233
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
234
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
235
+ if HAS_BIAS:
236
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
237
+ row_end = min((row_block_id + 1) * rows_per_program, M)
238
+ for row in range(row_start, row_end):
239
+ # Load data to SRAM
240
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
241
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
242
+ if not IS_RMS_NORM:
243
+ mean = tl.load(Mean + row)
244
+ if HAS_Z and not NORM_BEFORE_GATE:
245
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
246
+ x_og = x
247
+ x = x_og * z * tl.sigmoid(z)
248
+ rstd = tl.load(Rstd + row)
249
+ # Compute dx
250
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
251
+ xhat = tl.where(mask, xhat, 0.)
252
+ if HAS_Z and NORM_BEFORE_GATE:
253
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
254
+ z_sigmoid = tl.sigmoid(z)
255
+ y = xhat * w + b if HAS_BIAS else xhat * w
256
+ if RECOMPUTE_OUTPUT:
257
+ tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
258
+ dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
259
+ tl.store(DZ + cols, dz, mask=mask)
260
+ dy *= z * z_sigmoid
261
+ else:
262
+ if RECOMPUTE_OUTPUT:
263
+ y = xhat * w + b if HAS_BIAS else xhat * w
264
+ tl.store(Y + cols, y, mask=mask)
265
+ wdy = w * dy
266
+ c1 = tl.sum(xhat * wdy, axis=0) / N
267
+ if not IS_RMS_NORM:
268
+ c2 = tl.sum(wdy, axis=0) / N
269
+ dx = (wdy - (xhat * c1 + c2)) * rstd
270
+ else:
271
+ dx = (wdy - xhat * c1) * rstd
272
+ dw += dy * xhat
273
+ if HAS_BIAS:
274
+ db += dy
275
+ if HAS_Z and not NORM_BEFORE_GATE:
276
+ z_sigmoid = tl.sigmoid(z)
277
+ dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
278
+ tl.store(DZ + cols, dz, mask=mask)
279
+ dx *= z * z_sigmoid
280
+ # Write dx
281
+ tl.store(DX + cols, dx, mask=mask)
282
+
283
+ X += stride_x_row
284
+ if HAS_Z:
285
+ Z += stride_z_row
286
+ DZ += stride_dz_row
287
+ if RECOMPUTE_OUTPUT:
288
+ Y += stride_y_row
289
+ DY += stride_dy_row
290
+ DX += stride_dx_row
291
+ tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
292
+ if HAS_BIAS:
293
+ tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
294
+
295
+
296
+ def layer_norm_bwd(
297
+ dy: torch.Tensor,
298
+ x: torch.Tensor,
299
+ weight: torch.Tensor,
300
+ bias: torch.Tensor,
301
+ eps: float,
302
+ mean: torch.Tensor,
303
+ rstd: torch.Tensor,
304
+ z: torch.Tensor = None,
305
+ group_size: int = None,
306
+ norm_before_gate: bool = True,
307
+ is_rms_norm: bool = False,
308
+ recompute_output: bool = False,
309
+ dz: torch.Tensor = None,
310
+ out: torch.Tensor = None,
311
+ ):
312
+ M, N = x.shape
313
+ if group_size is None:
314
+ group_size = N
315
+ assert N % group_size == 0
316
+ ngroups = N // group_size
317
+ assert x.stride(-1) == 1
318
+ assert dy.stride(-1) == 1
319
+ assert dy.shape == (M, N)
320
+ if z is not None:
321
+ assert z.stride(-1) == 1
322
+ assert z.shape == (M, N)
323
+ assert weight.shape == (N,)
324
+ assert weight.stride(-1) == 1
325
+ if bias is not None:
326
+ assert bias.stride(-1) == 1
327
+ assert bias.shape == (N,)
328
+ # allocate output
329
+ dx = torch.empty_like(x)
330
+ if dz is not None:
331
+ assert z is not None
332
+ assert dz.shape == z.shape
333
+ assert dz.stride(-1) == 1
334
+ else:
335
+ dz = torch.empty_like(z) if z is not None else None
336
+ if recompute_output:
337
+ if out is None:
338
+ out = torch.empty_like(x)
339
+ assert out.shape == x.shape
340
+
341
+ # Less than 64KB per feature: enqueue fused kernel
342
+ MAX_FUSED_SIZE = 65536 // x.element_size()
343
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
344
+ if group_size > BLOCK_N:
345
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
346
+ # heuristics for number of warps
347
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
348
+ sm_count = get_multiprocessor_count(x.device.index)
349
+ # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
350
+ # would limit the occupancy.
351
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
352
+ _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
353
+ _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
354
+ rows_per_program = math.ceil(M / nrow_groups)
355
+ grid = (nrow_groups, ngroups)
356
+ layer_norm_bwd_kernel[grid](
357
+ x,
358
+ weight,
359
+ bias,
360
+ z,
361
+ out if recompute_output else None,
362
+ dy,
363
+ dx,
364
+ _dw,
365
+ _db,
366
+ dz,
367
+ mean,
368
+ rstd,
369
+ x.stride(0),
370
+ z.stride(0) if z is not None else 0,
371
+ 0 if not recompute_output else out.stride(0),
372
+ dy.stride(0),
373
+ dx.stride(0),
374
+ dz.stride(0) if dz is not None else 0,
375
+ _dw.stride(0),
376
+ _db.stride(0) if _db is not None else 0,
377
+ M, group_size, eps,
378
+ rows_per_program,
379
+ BLOCK_N=BLOCK_N,
380
+ NORM_BEFORE_GATE=norm_before_gate,
381
+ IS_RMS_NORM=is_rms_norm,
382
+ num_warps=num_warps
383
+ )
384
+ dw = _dw.sum(0).to(weight.dtype)
385
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
386
+ return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
387
+
388
+
389
+ class LayerNormFn(torch.autograd.Function):
390
+
391
+ @input_guard
392
+ @staticmethod
393
+ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
394
+ is_rms_norm=False):
395
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
396
+ """
397
+
398
+ x_shape_og = x.shape
399
+ # reshape input data into 2D tensor
400
+ x = x.reshape(-1, x.shape[-1])
401
+ if x.stride(-1) != 1:
402
+ x = x.contiguous()
403
+ if z is not None:
404
+ assert z.shape == x_shape_og
405
+ z = z.reshape(-1, z.shape[-1])
406
+ if z.stride(-1) != 1:
407
+ z = z.contiguous()
408
+ weight = weight.contiguous()
409
+ if bias is not None:
410
+ bias = bias.contiguous()
411
+ y, mean, rstd = layer_norm_fwd(
412
+ x,
413
+ weight,
414
+ bias,
415
+ eps,
416
+ z=z,
417
+ group_size=group_size,
418
+ norm_before_gate=norm_before_gate,
419
+ is_rms_norm=is_rms_norm,
420
+ )
421
+ ctx.save_for_backward(x, weight, bias, mean, rstd, z)
422
+ ctx.x_shape_og = x_shape_og
423
+ ctx.eps = eps
424
+ ctx.group_size = group_size
425
+ ctx.norm_before_gate = norm_before_gate
426
+ ctx.is_rms_norm = is_rms_norm
427
+ return y.reshape(x_shape_og)
428
+
429
+ @input_guard
430
+ @staticmethod
431
+ def backward(ctx, dy):
432
+ x, weight, bias, mean, rstd, z = ctx.saved_tensors
433
+ dy = dy.reshape(-1, dy.shape[-1])
434
+ if dy.stride(-1) != 1:
435
+ dy = dy.contiguous()
436
+ assert dy.shape == x.shape
437
+ dx, dw, db, dz = layer_norm_bwd(
438
+ dy,
439
+ x,
440
+ weight,
441
+ bias,
442
+ ctx.eps,
443
+ mean,
444
+ rstd,
445
+ z,
446
+ ctx.group_size,
447
+ ctx.norm_before_gate,
448
+ ctx.is_rms_norm
449
+ )
450
+ dx = dx.reshape(ctx.x_shape_og)
451
+ dz = dz.reshape(ctx.x_shape_og) if dz is not None else None
452
+ return dx, dw, db, dz, None, None, None, None
453
+
454
+
455
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
456
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
457
+
458
+
459
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
460
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
461
+
462
+
463
+ class LayerNormGated(nn.Module):
464
+
465
+ def __init__(
466
+ self,
467
+ hidden_size,
468
+ eps: float = 1e-5,
469
+ group_size: Optional[int] = None,
470
+ norm_before_gate: bool = True,
471
+ device: Optional[torch.device] = None,
472
+ dtype: Optional[torch.dtype] = None,
473
+ ):
474
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
475
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
476
+ """
477
+
478
+ factory_kwargs = {"device": device, "dtype": dtype}
479
+ super().__init__()
480
+ self.eps = eps
481
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
482
+ self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
483
+ self.group_size = group_size
484
+ self.norm_before_gate = norm_before_gate
485
+ self.reset_parameters()
486
+
487
+ def reset_parameters(self):
488
+ torch.nn.init.ones_(self.weight)
489
+ torch.nn.init.zeros_(self.bias)
490
+
491
+ def forward(self, x, z=None):
492
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
493
+ """
494
+ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
495
+ norm_before_gate=self.norm_before_gate)
496
+
497
+
498
+ class RMSNormGated(nn.Module):
499
+
500
+ def __init__(
501
+ self,
502
+ hidden_size,
503
+ eps: float = 1e-5,
504
+ group_size: Optional[int] = None,
505
+ norm_before_gate: bool = False,
506
+ device: Optional[torch.device] = None,
507
+ dtype: Optional[torch.dtype] = None,
508
+ ):
509
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
510
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
511
+ """
512
+ factory_kwargs = {"device": device, "dtype": dtype}
513
+ super().__init__()
514
+ self.eps = eps
515
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
516
+ self.register_parameter("bias", None)
517
+ self.group_size = group_size
518
+ self.norm_before_gate = norm_before_gate
519
+ self.reset_parameters()
520
+
521
+ def reset_parameters(self):
522
+ torch.nn.init.ones_(self.weight)
523
+
524
+ def forward(self, x, z=None):
525
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
526
+ """
527
+ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
528
+ norm_before_gate=self.norm_before_gate)
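A usage sketch for RMSNormGated, checked against rms_norm_ref defined at the top of this file (shapes and the CUDA device are assumptions; with the default norm_before_gate=False the gate is applied before the norm):

import torch
from fla.modules.layernorm_gated import RMSNormGated, rms_norm_ref

norm = RMSNormGated(hidden_size=512, eps=1e-5).cuda()
x = torch.randn(2, 64, 512, device='cuda')
z = torch.randn(2, 64, 512, device='cuda')
y = norm(x, z)  # rms_norm(x * silu(z)) * weight
ref = rms_norm_ref(x, norm.weight, None, z=z, eps=1e-5, norm_before_gate=False)
torch.testing.assert_close(y, ref, rtol=1e-3, atol=1e-3)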
fla/modules/mlp.py ADDED
@@ -0,0 +1,127 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from functools import partial
7
+ from typing import TYPE_CHECKING, Any, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.distributed import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_module
13
+ from torch.distributed.tensor.parallel import ParallelStyle
14
+
15
+ from fla.modules.activations import swiglu, swiglu_linear
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.processing_utils import Unpack
19
+
20
+
21
+ class GatedMLP(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ hidden_ratio: Optional[int] = None,
27
+ intermediate_size: Optional[int] = None,
28
+ hidden_act: str = 'swish',
29
+ fuse_swiglu: bool = True
30
+ ) -> GatedMLP:
31
+ super().__init__()
32
+
33
+ self.hidden_size = hidden_size
34
+ # the final number of params is `hidden_ratio * hidden_size^2`
35
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
36
+ if hidden_ratio is None:
37
+ hidden_ratio = 4
38
+ if intermediate_size is None:
39
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
40
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
41
+ self.hidden_ratio = hidden_ratio
42
+ self.intermediate_size = intermediate_size
43
+ self.hidden_act = hidden_act
44
+ self.fuse_swiglu = fuse_swiglu
45
+
46
+ if hidden_act != 'swish':
47
+ raise ValueError(f'Unsupported hidden_act: {hidden_act}')
48
+
49
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
50
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
51
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
52
+ if self.fuse_swiglu:
53
+ self.swiglu_linear = SwiGLULinear()
54
+
55
+ def forward(
56
+ self,
57
+ x: torch.Tensor,
58
+ **kwargs: Unpack[Any]
59
+ ) -> torch.Tensor:
60
+ gate, y = self.gate_proj(x), self.up_proj(x)
61
+ if self.fuse_swiglu:
62
+ return self.swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
63
+ else:
64
+ return self.down_proj(swiglu(gate, y))
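A worked example of the intermediate_size rule in __init__ above (values are illustrative): with hidden_size=2048 and hidden_ratio=4, 2/3 * 2048 * 4 ≈ 5461.3, which is rounded up to the nearest multiple of 256, giving 5632.

from fla.modules.mlp import GatedMLP

mlp = GatedMLP(hidden_size=2048, hidden_ratio=4, fuse_swiglu=False)
assert mlp.intermediate_size == 5632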
65
+
66
+
67
+ class SwiGLULinear(nn.Module):
68
+
69
+ def forward(self, x, y, weight, bias):
70
+ return swiglu_linear(x, y, weight, bias)
71
+
72
+
73
+ class SwiGLULinearParallel(ParallelStyle):
74
+ def __init__(
75
+ self,
76
+ *,
77
+ input_layouts: Optional[Placement] = None,
78
+ output_layouts: Optional[Placement] = None,
79
+ use_local_output: bool = True,
80
+ ):
81
+ super().__init__()
82
+ self.input_layouts = (input_layouts or Shard(-1),)
83
+ self.output_layouts = (output_layouts or Replicate(),)
84
+ self.desired_input_layouts = (Shard(-1),)
85
+ self.use_local_output = use_local_output
86
+
87
+ @staticmethod
88
+ def _prepare_input_fn(
89
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
90
+ ):
91
+ x, y, weight, bias = inputs
92
+ if not isinstance(x, DTensor):
93
+ x = DTensor.from_local(x, device_mesh, input_layouts, run_check=False)
94
+ if x.placements != desired_input_layouts:
95
+ x = x.redistribute(placements=desired_input_layouts, async_op=True)
96
+
97
+ if not isinstance(y, DTensor):
98
+ y = DTensor.from_local(y, device_mesh, input_layouts, run_check=False)
99
+ if y.placements != desired_input_layouts:
100
+ y = y.redistribute(placements=desired_input_layouts, async_op=True)
101
+
102
+ if not isinstance(weight, DTensor):
103
+ weight = DTensor.from_local(weight, device_mesh, (Shard(1),))
104
+
105
+ if bias is not None and not isinstance(bias, DTensor):
106
+ bias = DTensor.from_local(bias, device_mesh, (Replicate(),))
107
+
108
+ return x, y, weight, bias
109
+
110
+ @staticmethod
111
+ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
112
+ # Rowwise sharding produces partial output, depending on output layouts:
113
+ # 1. to replicate -> allreduce
114
+ # 2. to shard -> reduce_scatter
115
+ if outputs.placements != output_layouts:
116
+ outputs = outputs.redistribute(placements=output_layouts, async_op=True)
117
+ # back to local tensor if use_local_output is True
118
+ return outputs.to_local() if use_local_output else outputs
119
+
120
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
121
+ return distribute_module(
122
+ module,
123
+ device_mesh,
124
+ partition_fn=None,
125
+ input_fn=partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
126
+ output_fn=partial(self._prepare_output_fn, self.output_layouts, self.use_local_output)
127
+ )
fla/ops/attn/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .parallel import parallel_attn
4
+ from .parallel_rectified import parallel_rectified_attn
5
+ from .parallel_softpick import parallel_softpick_attn
6
+ from .naive import naive_attn
7
+ from .naive_rectified import naive_rectified_attn
8
+ from .naive_softpick import naive_softpick_attn
9
+
10
+ __all__ = [
11
+ 'parallel_attn',
12
+ 'parallel_rectified_attn',
13
+ 'parallel_softpick_attn',
14
+ 'naive_attn',
15
+ 'naive_rectified_attn',
16
+ 'naive_softpick_attn',
17
+ ]
fla/ops/attn/__pycache__/parallel_softpick.cpython-312.pyc ADDED
Binary file (34.9 kB).
 
fla/ops/attn/naive_softpick.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Optional
4
+ from einops import rearrange
5
+
6
+ def softpick(x, dim=-1, eps=1e-8):
7
+ # softpick function: relu(exp(x)-1) / sum(abs(exp(x)-1))
8
+ # numerically stable version
9
+ x_m = torch.max(x, dim=dim, keepdim=True).values
10
+ x_m_e_m = torch.exp(-x_m)
11
+ x_e_1 = torch.exp(x - x_m) - x_m_e_m
12
+ r_x_e_1 = F.relu(x_e_1)
13
+ a_x_e_1 = torch.where(x.isfinite(), torch.abs(x_e_1), 0)
14
+ return r_x_e_1 / (torch.sum(a_x_e_1, dim=dim, keepdim=True) + eps) # epsilon is only useful if all inputs are EXACTLY 0. we might not even need it
15
+
16
+ def naive_softpick_attn(
17
+ q: torch.Tensor,
18
+ k: torch.Tensor,
19
+ v: torch.Tensor,
20
+ scale: Optional[float] = None,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = False
23
+ ) -> torch.Tensor:
24
+ head_dim = q.shape[-1]
25
+ if scale is None:
26
+ scale = 1.0 / (head_dim ** 0.5)
27
+ if not head_first:
28
+ q, k, v = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (q, k, v))
29
+ q_len = q.shape[-2]
30
+ k_len = k.shape[-2]
31
+ mask = torch.tril(torch.ones(k_len, k_len, device=q.device))
32
+ wei = torch.matmul(q, k.transpose(2, 3)) # shape: (batch_size, num_heads, q_len, k_len)
33
+ wei = wei * scale
34
+ wei = wei.masked_fill(mask[k_len-q_len:k_len, :k_len] == 0, float('-inf'))
35
+ wei = softpick(wei.float(), dim=-1).to(q.dtype)
36
+ o = torch.matmul(wei, v) # shape: (batch_size, num_heads, q_len, head_dim)
37
+ if not head_first:
38
+ o = rearrange(o, 'b h t d -> b t h d')
39
+ return o, wei
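A tiny numerical illustration of softpick vs. softmax (values are illustrative): strictly negative scores map exactly to zero and a row can sum to less than 1.

import torch
from fla.ops.attn.naive_softpick import softpick

x = torch.tensor([[2.0, 0.0, -1.0]])
p = softpick(x, dim=-1)
# relu(exp(x) - 1) = [6.389, 0, 0]; sum(|exp(x) - 1|) ~= 7.021, so p ~= [0.910, 0.0, 0.0]
assert p[0, 2] == 0 and p.sum() < 1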
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # Since we need to keep all of DK in SRAM, we face severe SRAM memory pressure; sub-chunking alleviates this burden.
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
320
+
321
+
322
+ def chunk_gated_delta_rule_bwd_dhu(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ w: torch.Tensor,
326
+ g: torch.Tensor,
327
+ h0: torch.Tensor,
328
+ dht: Optional[torch.Tensor],
329
+ do: torch.Tensor,
330
+ dv: torch.Tensor,
331
+ scale: float,
332
+ offsets: Optional[torch.LongTensor] = None,
333
+ indices: Optional[torch.LongTensor] = None,
334
+ head_first: bool = True,
335
+ chunk_size: int = 64
336
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
337
+ if head_first:
338
+ B, H, T, K, V = *q.shape, do.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *q.shape, do.shape[-1]
341
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
342
+ # N: the actual number of sequences in the batch with either equal or variable lengths
343
+ if offsets is None:
344
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
345
+ else:
346
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
347
+
348
+ BK = triton.next_power_of_2(K)
349
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
350
+
351
+ # H100
352
+ if check_shared_mem('hopper', q.device.index):
353
+ BV = 64
354
+ BC = 64 if K <= 128 else 32
355
+ # A100
356
+ elif check_shared_mem('ampere', q.device.index):
357
+ BV = 32
358
+ BC = 64 if K <= 128 else 32
359
+ else:
360
+ BV = 32 if K <= 128 else 16
361
+ BC = 16
362
+
363
+ BC = min(BT, BC)
364
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
365
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
366
+
367
+ if head_first:
368
+ dh = q.new_empty(B, H, NT, K, V)
369
+ else:
370
+ dh = q.new_empty(B, NT, H, K, V)
371
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
372
+ dv2 = torch.empty_like(dv)
373
+
374
+ grid = (NK, NV, N * H)
375
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
376
+ q=q,
377
+ k=k,
378
+ d=w,
379
+ g=g,
380
+ dht=dht,
381
+ dh0=dh0,
382
+ do=do,
383
+ dh=dh,
384
+ dv=dv,
385
+ dv2=dv2,
386
+ offsets=offsets,
387
+ chunk_offsets=chunk_offsets,
388
+ scale=scale,
389
+ T=T,
390
+ H=H,
391
+ K=K,
392
+ V=V,
393
+ BT=BT,
394
+ BC=BC,
395
+ BK=BK,
396
+ BV=BV,
397
+ HEAD_FIRST=head_first
398
+ )
399
+ return dh, dh0, dv2
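A minimal usage sketch of the forward wrapper above, assuming the file is importable as fla.ops.common.chunk_delta_h; shapes, dtypes and values are illustrative only:

import torch
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h

# illustrative shapes: batch 2, 4 heads, 128 steps, K = V = 64, head_first layout [B, H, T, ...]
B, H, T, K, V = 2, 4, 128, 64, 64
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
w = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)  # passed to the kernel as `d`
u = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)  # pseudo values, corrected into v_new
g = torch.zeros(B, H, T, device='cuda', dtype=torch.float)        # optional per-step log decay

h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
    k=k, w=w, u=u, g=g,
    initial_state=None,
    output_final_state=True,
    head_first=True,
    chunk_size=64,
)
# h: [B, H, NT, K, V] per-chunk states; v_new: same shape as u; final_state: [B, H, K, V] in fp32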
fla/ops/common/chunk_h.py ADDED
@@ -0,0 +1,422 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem
13
+
14
+ BKV_LIST = [32, 64] if check_shared_mem() else [16, 32]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in BKV_LIST
26
+ for BV in BKV_LIST
27
+ for num_warps in [1, 2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ split_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BS: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ USE_G: tl.constexpr,
53
+ USE_GK: tl.constexpr,
54
+ USE_GV: tl.constexpr,
55
+ USE_INITIAL_STATE: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ NS = tl.cdiv(T, BS)
67
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
68
+ else:
69
+ bos, eos = i_n * T, i_n * T + T
70
+ NT = tl.cdiv(T, BT)
71
+ NS = tl.cdiv(T, BS)
72
+ boh = i_n * NS
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
79
+
80
+ for i_t in range(NT):
81
+ i_s = i_t // (BS // BT)
82
+ if HEAD_FIRST:
83
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
84
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+
86
+ o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
87
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
88
+ else:
89
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
90
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+
92
+ o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
93
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
94
+
95
+ if i_t % (BS // BT) == 0:
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+ # [BK, BT]
98
+ b_k = tl.load(p_k, boundary_check=(0, 1))
99
+ # [BT, BV]
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ last_idx = min((i_t + 1) * BT, T) - 1
102
+
103
+ # scalar decay
104
+ if USE_G:
105
+ if HEAD_FIRST:
106
+ b_g_last = tl.load(g + i_nh * T + last_idx)
107
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
108
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
109
+ else:
110
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
111
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
112
+ b_h *= exp(b_g_last)
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+ b_h *= exp(b_gk_last)[:, None]
128
+
129
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
130
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
131
+
132
+ # vector decay, h = h @ Diag(gv)
133
+ if USE_GV:
134
+ if HEAD_FIRST:
135
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
136
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
137
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
138
+ else:
139
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
140
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
141
+
142
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
143
+ b_h *= exp(b_gv_last)[None, :]
144
+
145
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
146
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
147
+
148
+ b_h += tl.dot(b_k, b_v)
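+ # chunkwise state recurrence: with a scalar decay g, h <- exp(g_last) * h + sum_t exp(g_last - g_t) * k_t v_t^T; gk/gv apply the analogous per-dimension decays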
149
+
150
+ if STORE_FINAL_STATE:
151
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
157
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
158
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
159
+ })
160
+ @triton.autotune(
161
+ configs=[
162
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
163
+ for BK in BKV_LIST
164
+ for BV in BKV_LIST
165
+ for num_warps in [1, 2, 4, 8]
166
+ for num_stages in [2, 3, 4]
167
+ ],
168
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
169
+ )
170
+ @triton.jit(do_not_specialize=['T'])
171
+ def chunk_bwd_kernel_dh(
172
+ q,
173
+ g,
174
+ gk,
175
+ gv,
176
+ do,
177
+ dh,
178
+ dht,
179
+ dh0,
180
+ offsets,
181
+ split_offsets,
182
+ scale,
183
+ T,
184
+ HQ: tl.constexpr,
185
+ H: tl.constexpr,
186
+ K: tl.constexpr,
187
+ V: tl.constexpr,
188
+ BT: tl.constexpr,
189
+ BS: tl.constexpr,
190
+ BK: tl.constexpr,
191
+ BV: tl.constexpr,
192
+ NG: tl.constexpr,
193
+ USE_G: tl.constexpr,
194
+ USE_GK: tl.constexpr,
195
+ USE_GV: tl.constexpr,
196
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
197
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_bg = i_nh // NG
203
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
204
+ i_h = i_hq // NG
205
+ if USE_OFFSETS:
206
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
207
+ T = eos - bos
208
+ NT = tl.cdiv(T, BT)
209
+ NS = tl.cdiv(T, BS)
210
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
211
+ else:
212
+ bos, eos = i_n * T, i_n * T + T
213
+ NT = tl.cdiv(T, BT)
214
+ NS = tl.cdiv(T, BS)
215
+ boh = i_n * NS
216
+
217
+ # [BK, BV]
218
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
219
+ if USE_FINAL_STATE_GRADIENT:
220
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
221
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
222
+
223
+ for i_t in range(NT - 1, -1, -1):
224
+ i_s = i_t // (BS // BT)
225
+ if HEAD_FIRST:
226
+ o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
227
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
228
+ else:
229
+ o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
230
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
231
+
232
+ if i_t % (BS // BT) == 0:
233
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
234
+ last_idx = min(i_t * BT + BT, T) - 1
235
+ # [BK, BT]
236
+ if HEAD_FIRST:
237
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
238
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ else:
240
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
241
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ b_q = tl.load(p_q, boundary_check=(0, 1))
243
+ b_q = (b_q * scale).to(b_q.dtype)
244
+ # [BT, BV]
245
+ b_do = tl.load(p_do, boundary_check=(0, 1))
246
+
247
+ if USE_G:
248
+ if HEAD_FIRST:
249
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
250
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
251
+ b_g_last = tl.load(g + i_bg * T + last_idx)
252
+ else:
253
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
254
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
255
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
256
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
257
+
258
+ b_dh *= exp(b_g_last)
259
+
260
+ if USE_GK:
261
+ if HEAD_FIRST:
262
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
263
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
264
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
265
+ else:
266
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
267
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+
269
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
270
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
271
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
272
+ b_dh *= exp(b_gk_last)[:, None]
273
+
274
+ if USE_GV:
275
+ if HEAD_FIRST:
276
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
277
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
278
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
279
+ else:
280
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
281
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
282
+
283
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
284
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
285
+
286
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
287
+ b_dh *= exp(b_gv_last)[None, :]
288
+
289
+ b_dh += tl.dot(b_q, b_do)
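+ # backward state recurrence, walked from the last chunk to the first: dh <- (chunk decay) * dh + q_t^T @ do_t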
290
+
291
+ if STORE_INITIAL_STATE_GRADIENT:
292
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
293
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
294
+
295
+
296
+ def chunk_fwd_h(
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ g: torch.Tensor,
300
+ gk: torch.Tensor,
301
+ gv: torch.Tensor,
302
+ h0: torch.Tensor,
303
+ output_final_state: bool,
304
+ offsets: Optional[torch.Tensor] = None,
305
+ head_first: bool = True,
306
+ chunk_size: int = 64,
307
+ split_size: Optional[int] = None,
308
+ states_in_fp32: bool = False
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ if head_first:
311
+ B, H, T, K, V = *k.shape, v.shape[-1]
312
+ else:
313
+ B, T, H, K, V = *k.shape, v.shape[-1]
314
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
315
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
316
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
317
+ # N: the actual number of sequences in the batch with either equal or variable lengths
318
+ if offsets is None:
319
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
320
+ else:
321
+ split_offsets = prepare_chunk_offsets(offsets, BS)
322
+ N, NS = len(offsets) - 1, split_offsets[-1]
323
+
324
+ if head_first:
325
+ h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
326
+ else:
327
+ h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
328
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
329
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
330
+ chunk_fwd_kernel_h[grid](
331
+ k=k,
332
+ v=v,
333
+ h=h,
334
+ g=g,
335
+ gk=gk,
336
+ gv=gv,
337
+ h0=h0,
338
+ ht=ht,
339
+ offsets=offsets,
340
+ split_offsets=split_offsets,
341
+ T=T,
342
+ H=H,
343
+ K=K,
344
+ V=V,
345
+ BT=BT,
346
+ BS=BS,
347
+ USE_G=g is not None,
348
+ USE_GK=gk is not None,
349
+ USE_GV=gv is not None,
350
+ HEAD_FIRST=head_first
351
+ )
352
+ return h, ht
353
+
354
+
355
+ def chunk_bwd_dh(
356
+ q: torch.Tensor,
357
+ k: torch.Tensor,
358
+ v: torch.Tensor,
359
+ g: torch.Tensor,
360
+ gk: torch.Tensor,
361
+ gv: torch.Tensor,
362
+ do: torch.Tensor,
363
+ h0: torch.Tensor,
364
+ dht: torch.Tensor,
365
+ scale: float,
366
+ offsets: Optional[torch.Tensor] = None,
367
+ head_first: bool = True,
368
+ chunk_size: int = 64,
369
+ split_size: Optional[int] = None,
370
+ states_in_fp32: bool = False
371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
372
+ if head_first:
373
+ B, H, T, K, V = *k.shape, v.shape[-1]
374
+ HQ = q.shape[1]
375
+ else:
376
+ B, T, H, K, V = *k.shape, v.shape[-1]
377
+ HQ = q.shape[2]
378
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
379
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
380
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
381
+ # N: the actual number of sequences in the batch with either equal or variable lengths
382
+ # NG: number of groups in GQA
383
+ if offsets is None:
384
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
385
+ else:
386
+ split_offsets = prepare_chunk_offsets(offsets, BS)
387
+ N, NS = len(offsets) - 1, split_offsets[-1]
388
+ NG = HQ // H
389
+
390
+ if head_first:
391
+ dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
392
+ else:
393
+ dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
394
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
395
+
396
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)  # one program per (sequence, query head), matching the kernel's i_nh decoding
397
+ chunk_bwd_kernel_dh[grid](
398
+ q=q,
399
+ g=g,
400
+ gk=gk,
401
+ gv=gv,
402
+ do=do,
403
+ dh=dh,
404
+ dht=dht,
405
+ dh0=dh0,
406
+ offsets=offsets,
407
+ split_offsets=split_offsets,
408
+ scale=scale,
409
+ T=T,
410
+ HQ=HQ,
411
+ H=H,
412
+ K=K,
413
+ V=V,
414
+ BT=BT,
415
+ BS=BS,
416
+ NG=NG,
417
+ USE_G=g is not None,
418
+ USE_GK=gk is not None,
419
+ USE_GV=gv is not None,
420
+ HEAD_FIRST=head_first
421
+ )
422
+ return dh, dh0
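A minimal call sketch for the chunk_fwd_h helper above (shapes are illustrative assumptions; the decays g/gk/gv and the initial state are simply omitted):

import torch
from fla.ops.common.chunk_h import chunk_fwd_h

B, H, T, K, V = 2, 4, 256, 64, 64  # assumed head_first shapes
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)

h, ht = chunk_fwd_h(
    k=k, v=v, g=None, gk=None, gv=None, h0=None,
    output_final_state=True,
    head_first=True,
    chunk_size=64,    # BT
    split_size=None,  # BS defaults to BT, so one state is materialized per chunk
)
# h: [B, H, NS, K, V] with NS = ceil(T / BS); ht: [B, H, K, V] final states in fp32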
fla/ops/common/chunk_h_split.py ADDED
@@ -0,0 +1,677 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in [32, 64]
22
+ for BV in [32, 64]
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3]
25
+ ],
26
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def chunk_fwd_kernel_h_split(
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ hs,
36
+ hr,
37
+ h0,
38
+ ht,
39
+ offsets,
40
+ split_indices,
41
+ T,
42
+ S: tl.constexpr,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_GK: tl.constexpr,
51
+ USE_GV: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr
56
+ ):
57
+ # handle one split at a time
58
+ # i_h: head index
59
+ # i_n: sequence index
60
+ # i_s: local split index inside a sequence
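+ # i_ss: global split index over all (sequence, split) pairs; split_indices decodes it into (i_n, i_s) for variable-length batches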
61
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_ss, i_h = i_sh // H, i_sh % H
63
+ if USE_OFFSETS:
64
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ NS = tl.cdiv(T, S)
68
+ else:
69
+ NS = tl.cdiv(T, S)
70
+ i_n, i_s = i_ss // NS, i_ss % NS
71
+ bos, eos = i_n * T, i_n * T + T
72
+ i_nh = i_n * H + i_h
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ # for the first split, the incoming state (h0 or zeros) is stored directly as its reduced state, since there is nothing earlier to accumulate
77
+ if i_s == 0:
78
+ if USE_INITIAL_STATE:
79
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
81
+ p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
83
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
86
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ else:
88
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
89
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ # [BK, BT]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BT, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ last_idx = min(i_t * BT + BT, T) - 1
95
+
96
+ # scalar decay
97
+ if USE_G:
98
+ if HEAD_FIRST:
99
+ b_g_last = tl.load(g + i_nh * T + last_idx)
100
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
101
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
102
+ else:
103
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
104
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
105
+ b_h *= exp(b_g_last)
106
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
107
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
108
+
109
+ # vector decay, h = Diag(gk) @ h
110
+ if USE_GK:
111
+ if HEAD_FIRST:
112
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
113
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
114
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
115
+ else:
116
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
117
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
118
+
119
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
120
+ b_h *= exp(b_gk_last)[:, None]
121
+
122
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
123
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
124
+
125
+ # vector decay, h = h @ Diag(gv)
126
+ if USE_GV:
127
+ if HEAD_FIRST:
128
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
129
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
130
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
131
+ else:
132
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
133
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
134
+
135
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
136
+ b_h *= exp(b_gv_last)[None, :]
137
+
138
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
139
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
140
+
141
+ b_h += tl.dot(b_k, b_v)
142
+
143
+ # if there is more than one split, we store the (unreduced) result to hs
144
+ # otherwise, we store the result to ht as the final state
145
+ if NS > 1:
146
+ p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
147
+ tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1))
148
+ elif STORE_FINAL_STATE:
149
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
155
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
156
+ })
157
+ @triton.autotune(
158
+ configs=[
159
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
160
+ for BK in [32, 64]
161
+ for BV in [32, 64]
162
+ for num_warps in [2, 4, 8]
163
+ for num_stages in [2, 3, 4]
164
+ ],
165
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
166
+ )
167
+ @triton.jit(do_not_specialize=['T'])
168
+ def chunk_fwd_kernel_h_reduction(
169
+ g,
170
+ gk,
171
+ gv,
172
+ hs,
173
+ hr,
174
+ ht,
175
+ offsets,
176
+ split_offsets,
177
+ T,
178
+ S: tl.constexpr,
179
+ H: tl.constexpr,
180
+ K: tl.constexpr,
181
+ V: tl.constexpr,
182
+ BT: tl.constexpr,
183
+ BK: tl.constexpr,
184
+ BV: tl.constexpr,
185
+ USE_G: tl.constexpr,
186
+ USE_GK: tl.constexpr,
187
+ USE_GV: tl.constexpr,
188
+ STORE_FINAL_STATE: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr,
190
+ HEAD_FIRST: tl.constexpr
191
+ ):
192
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
193
+ i_n, i_h = i_nh // H, i_nh % H
194
+ if USE_OFFSETS:
195
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
196
+ T = eos - bos
197
+ NS = tl.cdiv(T, S)
198
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
199
+ else:
200
+ bos, eos = i_n * T, i_n * T + T
201
+ NS = tl.cdiv(T, S)
202
+ boh = i_n * NS
203
+
204
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
205
+ # skip the first split
206
+ for i_s in range(1, NS):
207
+ p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
208
+ p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
209
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
210
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
211
+
212
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
213
+ last_idx = min(i_t * BT + BT, T) - 1
214
+ # scalar decay
215
+ if USE_G:
216
+ if HEAD_FIRST:
217
+ b_g_last = tl.load(g + i_nh * T + last_idx)
218
+ else:
219
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
220
+ b_h *= exp(b_g_last)
221
+
222
+ # vector decay, h = Diag(gk) @ h
223
+ if USE_GK:
224
+ if HEAD_FIRST:
225
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
226
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
227
+ else:
228
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
229
+
230
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
231
+ b_h *= exp(b_gk_last)[:, None]
232
+
233
+ # vector decay, h = h @ Diag(gv)
234
+ if USE_GV:
235
+ if HEAD_FIRST:
236
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
237
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
238
+ else:
239
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
240
+
241
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
242
+ b_h *= exp(b_gv_last)[None, :]
243
+
244
+ if NS > 1:
245
+ if STORE_FINAL_STATE:
246
+ p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
247
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
248
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
249
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
250
+
251
+
252
+ @triton.heuristics({
253
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
254
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
255
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
256
+ })
257
+ @triton.autotune(
258
+ configs=[
259
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
260
+ for BK in [32, 64]
261
+ for BV in [32, 64]
262
+ for num_warps in [2, 4, 8]
263
+ for num_stages in [2, 3]
264
+ ],
265
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
266
+ )
267
+ @triton.jit(do_not_specialize=['T'])
268
+ def chunk_bwd_kernel_dh_split(
269
+ q,
270
+ g,
271
+ gk,
272
+ gv,
273
+ do,
274
+ dht,
275
+ dhs,
276
+ dhr,
277
+ dh0,
278
+ offsets,
279
+ split_indices,
280
+ scale,
281
+ T,
282
+ S: tl.constexpr,
283
+ HQ: tl.constexpr,
284
+ H: tl.constexpr,
285
+ K: tl.constexpr,
286
+ V: tl.constexpr,
287
+ BT: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr,
290
+ NG: tl.constexpr,
291
+ USE_G: tl.constexpr,
292
+ USE_GK: tl.constexpr,
293
+ USE_GV: tl.constexpr,
294
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
295
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
296
+ USE_OFFSETS: tl.constexpr,
297
+ HEAD_FIRST: tl.constexpr
298
+ ):
299
+ # handle one split at a time
300
+ # i_h: head index
301
+ # i_n: sequence index
302
+ # i_s: local split index inside a sequence
303
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
304
+ i_ss, i_hq = i_sh // HQ, i_sh % HQ
305
+ if USE_OFFSETS:
306
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
307
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
308
+ T = eos - bos
309
+ NS = tl.cdiv(T, S)
310
+ else:
311
+ NS = tl.cdiv(T, S)
312
+ i_n, i_s = i_ss // NS, i_ss % NS
313
+ bos, eos = i_n * T, i_n * T + T
314
+ i_nh = i_n * HQ + i_hq
315
+ i_ng, i_h = i_nh // NG, i_hq // NG
316
+
317
+ # [BK, BV]
318
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
319
+ if i_s == NS - 1:
320
+ if USE_FINAL_STATE_GRADIENT:
321
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
323
+ p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
325
+
326
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
327
+ if HEAD_FIRST:
328
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
329
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
330
+ else:
331
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
333
+
334
+ b_q = tl.load(p_q, boundary_check=(0, 1))
335
+ b_q = (b_q * scale).to(b_q.dtype)
336
+ # [BT, BV]
337
+ b_do = tl.load(p_do, boundary_check=(0, 1))
338
+
339
+ last_idx = min(i_t * BT + BT, T) - 1
340
+ if USE_G:
341
+ if HEAD_FIRST:
342
+ p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT)
343
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
344
+ b_g_last = tl.load(g + i_ng * T + last_idx)
345
+ else:
346
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
347
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
348
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
349
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
350
+ b_dh *= exp(b_g_last)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
356
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
357
+ else:
358
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
359
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
360
+
361
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
362
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
363
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
364
+ b_dh *= exp(b_gk_last)[:, None]
365
+
366
+ if USE_GV:
367
+ if HEAD_FIRST:
368
+ p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
369
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
370
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
371
+ else:
372
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
374
+
375
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
376
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
377
+
378
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
379
+ b_dh *= exp(b_gv_last)[None, :]
380
+
381
+ b_dh += tl.dot(b_q, b_do)
382
+
383
+ if NS > 1:
384
+ p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
385
+ tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1))
386
+ elif STORE_INITIAL_STATE_GRADIENT:
387
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
388
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
389
+
390
+
391
+ @triton.heuristics({
392
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
393
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
394
+ })
395
+ @triton.autotune(
396
+ configs=[
397
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
398
+ for BK in [32, 64]
399
+ for BV in [32, 64]
400
+ for num_warps in [2, 4, 8]
401
+ for num_stages in [2, 3, 4]
402
+ ],
403
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
404
+ )
405
+ @triton.jit(do_not_specialize=['T'])
406
+ def chunk_bwd_kernel_dh_reduction(
407
+ g,
408
+ gk,
409
+ gv,
410
+ dhs,
411
+ dhr,
412
+ dh0,
413
+ offsets,
414
+ split_offsets,
415
+ T,
416
+ S: tl.constexpr,
417
+ H: tl.constexpr,
418
+ HQ: tl.constexpr,
419
+ K: tl.constexpr,
420
+ V: tl.constexpr,
421
+ BT: tl.constexpr,
422
+ BK: tl.constexpr,
423
+ BV: tl.constexpr,
424
+ NG: tl.constexpr,
425
+ USE_G: tl.constexpr,
426
+ USE_GK: tl.constexpr,
427
+ USE_GV: tl.constexpr,
428
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
429
+ USE_OFFSETS: tl.constexpr,
430
+ HEAD_FIRST: tl.constexpr
431
+ ):
432
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
433
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
434
+ i_ng, i_h = i_nh // NG, i_hq // NG
435
+ if USE_OFFSETS:
436
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
437
+ T = eos - bos
438
+ NS = tl.cdiv(T, S)
439
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
440
+ else:
441
+ bos, eos = i_n * T, i_n * T + T
442
+ NS = tl.cdiv(T, S)
443
+ boh = i_n * NS
444
+
445
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
446
+ for i_s in range(NS - 2, -1, -1):
447
+ p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
448
+ p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
449
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
450
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
451
+
452
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
453
+ last_idx = min(i_t * BT + BT, T) - 1
454
+ # scalar decay
455
+ if USE_G:
456
+ if HEAD_FIRST:
457
+ b_g_last = tl.load(g + i_ng * T + last_idx)
458
+ else:
459
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
460
+ b_dh *= exp(b_g_last)
461
+
462
+ if USE_GK:
463
+ if HEAD_FIRST:
464
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
465
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
466
+ else:
467
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
468
+
469
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
470
+ b_dh *= exp(b_gk_last)[:, None]
471
+
472
+ if USE_GV:
473
+ if HEAD_FIRST:
474
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
475
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
476
+ else:
477
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
478
+
479
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
480
+ b_dh *= exp(b_gv_last)[None, :]
481
+
482
+ if NS > 1:
483
+ if STORE_INITIAL_STATE_GRADIENT:
484
+ p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
485
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
486
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
487
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
488
+
489
+
490
+ def chunk_fwd_h(
491
+ k: torch.Tensor,
492
+ v: torch.Tensor,
493
+ g: torch.Tensor,
494
+ gk: torch.Tensor,
495
+ gv: torch.Tensor,
496
+ h0: torch.Tensor,
497
+ output_final_state: bool,
498
+ offsets: Optional[torch.LongTensor] = None,
499
+ split_offsets: Optional[torch.LongTensor] = None,
500
+ split_indices: Optional[torch.LongTensor] = None,
501
+ head_first: bool = True,
502
+ chunk_size: int = 64,
503
+ split_size: int = 256,
504
+ states_in_fp32: bool = True
505
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
506
+ if head_first:
507
+ B, H, T, K, V = *k.shape, v.shape[-1]
508
+ else:
509
+ B, T, H, K, V = *k.shape, v.shape[-1]
510
+ # B: batch size
511
+ # N: the actual number of sequences in the batch
512
+ # H: number of heads
513
+ # T: sequence length, can be variable across sequences
514
+ # S: split size, a multiple of chunk size
515
+ # BT: chunk size
516
+ S, BT = split_size, chunk_size
517
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
518
+ if offsets is None:
519
+ N = B
520
+ NS = N * triton.cdiv(T, S)
521
+ else:
522
+ N = len(offsets) - 1
523
+ NS = split_offsets[-1]
524
+
525
+ # unreduced kv states per split
526
+ hs = k.new_empty(NS, H, K, V, dtype=torch.float)
527
+ # reduced states per split
528
+ hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
529
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
530
+ # parallelized over splits
531
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H)
532
+ chunk_fwd_kernel_h_split[grid](
533
+ k=k,
534
+ v=v,
535
+ g=g,
536
+ gk=gk,
537
+ gv=gv,
538
+ hs=hs,
539
+ hr=hr,
540
+ h0=h0,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ split_indices=split_indices,
544
+ T=T,
545
+ S=S,
546
+ H=H,
547
+ K=K,
548
+ V=V,
549
+ BT=BT,
550
+ USE_G=g is not None,
551
+ USE_GK=gk is not None,
552
+ USE_GV=gv is not None,
553
+ HEAD_FIRST=head_first
554
+ )
555
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
556
+ chunk_fwd_kernel_h_reduction[grid](
557
+ g=g,
558
+ gk=gk,
559
+ gv=gv,
560
+ hs=hs,
561
+ hr=hr,
562
+ ht=ht,
563
+ offsets=offsets,
564
+ split_offsets=split_offsets,
565
+ T=T,
566
+ S=S,
567
+ H=H,
568
+ K=K,
569
+ V=V,
570
+ BT=BT,
571
+ USE_G=g is not None,
572
+ USE_GK=gk is not None,
573
+ USE_GV=gv is not None,
574
+ HEAD_FIRST=head_first
575
+ )
576
+ return hr, ht
577
+
578
+
579
+ def chunk_bwd_dh(
580
+ q: torch.Tensor,
581
+ k: torch.Tensor,
582
+ v: torch.Tensor,
583
+ g: torch.Tensor,
584
+ gk: torch.Tensor,
585
+ gv: torch.Tensor,
586
+ do: torch.Tensor,
587
+ h0: torch.Tensor,
588
+ dht: torch.Tensor,
589
+ scale: float,
590
+ offsets: Optional[torch.Tensor] = None,
591
+ split_offsets: Optional[torch.Tensor] = None,
592
+ split_indices: Optional[torch.Tensor] = None,
593
+ head_first: bool = True,
594
+ chunk_size: int = 64,
595
+ split_size: int = 256,
596
+ states_in_fp32: bool = True
597
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
598
+ if head_first:
599
+ B, H, T, K, V = *k.shape, v.shape[-1]
600
+ HQ = q.shape[1]
601
+ else:
602
+ B, T, H, K, V = *k.shape, v.shape[-1]
603
+ HQ = q.shape[2]
604
+ # B: batch size
605
+ # N: the actual number of sequences in the batch
606
+ # H: number of heads
607
+ # T: sequence length, can be variable across sequences
608
+ # S: split size, a multiple of chunk size
609
+ # BT: chunk size
610
+ S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size
611
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
612
+ if offsets is None:
613
+ N = B
614
+ NS = N * triton.cdiv(T, S)
615
+ else:
616
+ N = len(offsets) - 1
617
+ NS = split_offsets[-1]
618
+ # number of groups in GQA
619
+ NG = HQ // H
620
+
621
+ dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float)
622
+ dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
623
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
624
+
625
+ # parallelized over splits
626
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ)
627
+ chunk_bwd_kernel_dh_split[grid](
628
+ q=q,
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ do=do,
633
+ dht=dht,
634
+ dhs=dhs,
635
+ dhr=dhr,
636
+ dh0=dh0,
637
+ offsets=offsets,
638
+ split_indices=split_indices,
639
+ scale=scale,
640
+ T=T,
641
+ S=S,
642
+ HQ=HQ,
643
+ H=H,
644
+ K=K,
645
+ V=V,
646
+ BT=BT,
647
+ NG=NG,
648
+ USE_G=g is not None,
649
+ USE_GK=gk is not None,
650
+ USE_GV=gv is not None,
651
+ HEAD_FIRST=head_first,
652
+ )
653
+
654
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
655
+ chunk_bwd_kernel_dh_reduction[grid](
656
+ g=g,
657
+ gk=gk,
658
+ gv=gv,
659
+ dhs=dhs,
660
+ dhr=dhr,
661
+ dh0=dh0,
662
+ offsets=offsets,
663
+ split_offsets=split_offsets,
664
+ T=T,
665
+ S=S,
666
+ HQ=HQ,
667
+ H=H,
668
+ K=K,
669
+ V=V,
670
+ BT=BT,
671
+ NG=NG,
672
+ USE_G=g is not None,
673
+ USE_GK=gk is not None,
674
+ USE_GV=gv is not None,
675
+ HEAD_FIRST=head_first
676
+ )
677
+ return dhr, dh0
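The split variants above first materialize one unreduced state per split (chunk_fwd_kernel_h_split) and then accumulate them in a second reduction pass (chunk_fwd_kernel_h_reduction). As an illustrative example of the size relationship: with chunk_size=64 and split_size=256, a sequence of T=1024 steps has 4 splits of 4 chunks each. A minimal call sketch under those assumed shapes, using fixed-length sequences so offsets/split_offsets/split_indices stay None:

import torch
from fla.ops.common.chunk_h_split import chunk_fwd_h

B, H, T, K, V = 2, 4, 1024, 64, 64  # assumed head_first shapes
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)

hr, ht = chunk_fwd_h(
    k=k, v=v, g=None, gk=None, gv=None, h0=None,
    output_final_state=True,
    head_first=True,
    chunk_size=64,
    split_size=256,
)
# hr: [B * ceil(T / split_size), H, K, V] reduced per-split states (2 * 4 = 8 here)
# ht: [B, H, K, V] final states in fp32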
fla/ops/common/chunk_o.py ADDED
@@ -0,0 +1,668 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, safe_exp
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in BKV_LIST
25
+ for BV in BKV_LIST
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT'],
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_fwd_kernel_o(
33
+ q,
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ o,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+
56
+ if USE_OFFSETS:
57
+ i_tg = i_t
58
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ else:
63
+ NT = tl.cdiv(T, BT)
64
+ i_tg = i_b * NT + i_t
65
+ bos, eos = i_b * T, i_b * T + T
66
+
67
+ s_qk = K if HEAD_FIRST else H*K
68
+ s_vo = V if HEAD_FIRST else H*V
69
+ s_g = 1 if HEAD_FIRST else H
70
+ # offset calculation
71
+ q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
72
+ k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
73
+ v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
74
+ o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
75
+ h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)
76
+
77
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
78
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_k in range(tl.cdiv(K, BK)):
81
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ # [BK, BT]
87
+ b_k = tl.load(p_k, boundary_check=(0, 1))
88
+ # [BK, BV]
89
+ b_h = tl.load(p_h, boundary_check=(0, 1))
90
+
91
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+
96
+ if USE_G:
97
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
98
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
99
+ b_g = tl.load(p_g, boundary_check=(0,))
100
+ b_o = b_o * exp(b_g)[:, None]
101
+ b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
102
+
103
+ o_i = tl.arange(0, BT)
104
+ m_A = o_i[:, None] >= o_i[None, :]
105
+ b_A = tl.where(m_A, b_A, 0)
106
+
107
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
108
+ p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
109
+ b_v = tl.load(p_v, boundary_check=(0, 1))
110
+
111
+ # to fix mma -> mma layout conversion
112
+ # already fixed in Triton v3.2 and later
113
+ b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
114
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
115
+
116
+
117
+ @triton.heuristics({
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
119
+ 'USE_G': lambda args: args['g'] is not None,
120
+ 'USE_DW': lambda args: args['dw'] is not None
121
+ })
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
125
+ for num_warps in NUM_WARPS
126
+ for num_stages in [2, 3, 4]
127
+ ],
128
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'],
129
+ )
130
+ @triton.jit(do_not_specialize=['T'])
131
+ def chunk_bwd_kernel_dqkwg(
132
+ q,
133
+ k,
134
+ v,
135
+ h,
136
+ g,
137
+ do,
138
+ dh,
139
+ dq,
140
+ dk,
141
+ dg,
142
+ w,
143
+ dv,
144
+ dw,
145
+ offsets,
146
+ indices,
147
+ scale,
148
+ B: tl.constexpr,
149
+ T,
150
+ H: tl.constexpr,
151
+ K: tl.constexpr,
152
+ V: tl.constexpr,
153
+ BT: tl.constexpr,
154
+ BK: tl.constexpr,
155
+ BV: tl.constexpr,
156
+ USE_G: tl.constexpr,
157
+ USE_DW: tl.constexpr,
158
+ USE_OFFSETS: tl.constexpr,
159
+ HEAD_FIRST: tl.constexpr
160
+ ):
161
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
162
+ i_b, i_h = i_bh // H, i_bh % H
163
+ if USE_G:
164
+ dg += i_k * B * H * T
165
+ if USE_OFFSETS:
166
+ i_tg = i_t
167
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
168
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
169
+ T = eos - bos
170
+ NT = tl.cdiv(T, BT)
171
+ else:
172
+ NT = tl.cdiv(T, BT)
173
+ i_tg = i_b * NT + i_t
174
+ bos, eos = i_b * T, i_b * T + T
175
+
176
+ # offset calculation
177
+ v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
178
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
179
+ h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
180
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
181
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
182
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
183
+ dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
184
+ dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
185
+ s_qk = K if HEAD_FIRST else H*K
186
+ s_vo = V if HEAD_FIRST else H*V
187
+ s_g = 1 if HEAD_FIRST else H
188
+
189
+ # for delta rule only
190
+ if USE_DW:
191
+ dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
192
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
193
+ w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
194
+
195
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
196
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
197
+ b_ds = tl.zeros([BT, BT], dtype=tl.float32)
198
+ b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None
199
+ b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
200
+
201
+ for i_v in range(tl.cdiv(V, BV)):
202
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
203
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
205
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
206
+ # [BT, BV]
207
+ b_v = tl.load(p_v, boundary_check=(0, 1))
208
+ b_do = tl.load(p_do, boundary_check=(0, 1))
209
+ # [BV, BK]
210
+ b_h = tl.load(p_h, boundary_check=(0, 1))
211
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
212
+ if USE_G:
213
+ b_dg_last += (tl.sum(b_h * b_dh))
214
+ # [BT, BV] @ [BV, BT] -> [BT, BT]
215
+ b_ds += tl.dot(b_do, tl.trans(b_v))
216
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
217
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
218
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
219
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
220
+ if USE_DW:
221
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
222
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
223
+ b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
224
+
225
+ if USE_DW and not USE_G:
226
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ tl.debug_barrier()
230
+ o_i = tl.arange(0, BT)
231
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
232
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
233
+ b_q = tl.load(p_q, boundary_check=(0, 1))
234
+ b_k = tl.load(p_k, boundary_check=(0, 1))
235
+
236
+ p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
237
+ p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
238
+
239
+ if USE_G:
240
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
241
+ g += i_bh * T if HEAD_FIRST else bos * H + i_h
242
+ dg += i_bh * T if HEAD_FIRST else bos * H + i_h
243
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
244
+ b_g = tl.load(p_g, boundary_check=(0,))
245
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
246
+ b_dg_last *= exp(b_g_last)
247
+
248
+ if USE_DW:
249
+ p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
251
+ b_w = tl.load(p_w, boundary_check=(0, 1))
252
+ b_dw = b_dw * exp(b_g)[:, None]
253
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
254
+ b_dg -= tl.sum(b_w * b_dw, axis=1)
255
+
256
+ b_dq = b_dq * exp(b_g)[:, None] * scale
257
+ b_dg += tl.sum(b_dq * b_q, axis=1)
258
+
259
+ b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None]
260
+ b_dg -= tl.sum(b_k * b_dk, axis=1)
261
+ b_dg_last += tl.sum(b_dk * b_k)
262
+
263
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale
264
+ b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
265
+ b_dg += tl.sum(b_ds2, axis=1)
266
+ b_dg -= tl.sum(b_ds2, axis=0)
267
+
268
+ b_ds = b_ds.to(b_k.dtype)
269
+ # [BT, BK]
270
+ b_dq += tl.dot(b_ds, b_k)
271
+ b_dk += tl.dot(tl.trans(b_ds), b_q)
272
+ p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
273
+ # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
274
+ # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last
275
+ b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)
276
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
277
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
278
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
279
+ else:
280
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
281
+ b_ds = b_ds.to(b_k.dtype)
282
+ b_dq += tl.dot(b_ds, b_k)
283
+ b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
284
+ b_dq *= scale
285
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+
288
+
289
+ @triton.heuristics({
290
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
291
+ 'USE_G': lambda args: args['g'] is not None,
292
+ })
293
+ @triton.autotune(
294
+ configs=[
295
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
296
+ for num_warps in [2, 4, 8]
297
+ for num_stages in [2, 3, 4]
298
+ ],
299
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
300
+ )
301
+ @triton.jit(do_not_specialize=['T'])
302
+ def chunk_bwd_kernel_dv(
303
+ q,
304
+ k,
305
+ g,
306
+ do,
307
+ dv,
308
+ dh,
309
+ offsets,
310
+ indices,
311
+ scale,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ V: tl.constexpr,
316
+ BT: tl.constexpr,
317
+ BK: tl.constexpr,
318
+ BV: tl.constexpr,
319
+ USE_G: tl.constexpr,
320
+ USE_OFFSETS: tl.constexpr,
321
+ HEAD_FIRST: tl.constexpr
322
+ ):
323
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
324
+ i_b, i_h = i_bh // H, i_bh % H
325
+ if USE_OFFSETS:
326
+ i_tg = i_t
327
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
328
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
329
+ T = eos - bos
330
+ NT = tl.cdiv(T, BT)
331
+ else:
332
+ NT = tl.cdiv(T, BT)
333
+ i_tg = i_b * NT + i_t
334
+ bos, eos = i_b * T, i_b * T + T
335
+
336
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
337
+
338
+ # offset calculation
339
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
340
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
341
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
342
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
343
+ s_qk = K if HEAD_FIRST else H*K
344
+ s_vo = V if HEAD_FIRST else H*V
345
+ s_g = 1 if HEAD_FIRST else H
346
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
347
+
348
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
349
+ for i_k in range(tl.cdiv(K, BK)):
350
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
352
+ b_q = tl.load(p_q, boundary_check=(0, 1))
353
+ b_k = tl.load(p_k, boundary_check=(0, 1))
354
+ b_A += tl.dot(b_k, b_q)
355
+ p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
356
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
357
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
358
+
359
+ if USE_G:
360
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
361
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
362
+ b_g = tl.load(p_g, boundary_check=(0,))
363
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
364
+ b_dv *= safe_exp(-b_g + b_g_last)[:, None]
365
+
366
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
367
+ if USE_G:
368
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
369
+ else:
370
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
371
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
372
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ b_do = tl.load(p_do, boundary_check=(0, 1))
374
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
375
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
376
+
377
+
378
+ @triton.heuristics({
379
+ 'USE_G': lambda args: args['g'] is not None,
380
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
381
+ })
382
+ @triton.autotune(
383
+ configs=[
384
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
385
+ for num_warps in NUM_WARPS
386
+ for num_stages in [2, 3, 4]
387
+ ],
388
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
389
+ )
390
+ @triton.jit(do_not_specialize=['T'])
391
+ def chunk_bwd_kernel_dv_local(
392
+ q,
393
+ k,
394
+ g,
395
+ do,
396
+ dv,
397
+ offsets,
398
+ indices,
399
+ scale,
400
+ T,
401
+ H: tl.constexpr,
402
+ K: tl.constexpr,
403
+ V: tl.constexpr,
404
+ BT: tl.constexpr,
405
+ BK: tl.constexpr,
406
+ BV: tl.constexpr,
407
+ USE_G: tl.constexpr,
408
+ USE_OFFSETS: tl.constexpr,
409
+ HEAD_FIRST: tl.constexpr
410
+ ):
411
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
412
+ i_b, i_h = i_bh // H, i_bh % H
413
+ if USE_OFFSETS:
414
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
415
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
416
+ T = eos - bos
417
+ else:
418
+ bos, eos = i_b * T, i_b * T + T
419
+
420
+ # offset calculation
421
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
422
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
423
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
424
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
425
+ s_qk = K if HEAD_FIRST else H*K
426
+ s_vo = V if HEAD_FIRST else H*V
427
+ s_g = 1 if HEAD_FIRST else H
428
+
429
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
430
+ for i_k in range(tl.cdiv(K, BK)):
431
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
432
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
433
+ b_q = tl.load(p_q, boundary_check=(0, 1))
434
+ b_k = tl.load(p_k, boundary_check=(0, 1))
435
+ b_A += tl.dot(b_k, b_q)
436
+
437
+ if USE_G:
438
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
439
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
440
+ b_g = tl.load(p_g, boundary_check=(0,))
441
+
442
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
443
+ if USE_G:
444
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
445
+ else:
446
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
447
+
448
+ for i_v in range(tl.cdiv(V, BV)):
449
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
450
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
451
+ b_do = tl.load(p_do, boundary_check=(0, 1))
452
+ b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
453
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
454
+
455
+
456
+ def chunk_fwd_o(
457
+ q: torch.Tensor,
458
+ k: torch.Tensor,
459
+ v: torch.Tensor,
460
+ h: torch.Tensor,
461
+ g: Optional[torch.Tensor] = None, # cumsum of log decay
462
+ scale: Optional[float] = None,
463
+ offsets: Optional[torch.LongTensor] = None,
464
+ indices: Optional[torch.LongTensor] = None,
465
+ head_first: bool = True,
466
+ chunk_size: int = 64
467
+ ) -> torch.Tensor:
468
+ if head_first:
469
+ B, H, T, K, V = *q.shape, v.shape[-1]
470
+ else:
471
+ B, T, H, K, V = *q.shape, v.shape[-1]
472
+ if scale is None:
473
+ scale = k.shape[-1] ** -0.5
474
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
475
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
476
+
477
+ o = torch.empty_like(v)
478
+
479
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
480
+ chunk_fwd_kernel_o[grid](
481
+ q,
482
+ k,
483
+ v,
484
+ h,
485
+ g,
486
+ o,
487
+ offsets,
488
+ indices,
489
+ scale,
490
+ T=T,
491
+ H=H,
492
+ K=K,
493
+ V=V,
494
+ BT=BT,
495
+ HEAD_FIRST=head_first
496
+ )
497
+ return o
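For reference, here is a minimal PyTorch sketch of what `chunk_fwd_o` computes per chunk: the inter-chunk contribution `q @ h` plus the causally masked intra-chunk term `(q @ k^T) @ v`, both gated by the cumulative log-decay `g` when it is provided. This assumes the `head_first` layout `[B, H, T, ...]`, equal-length sequences, and `T` divisible by the chunk size; it mirrors the kernel's math, not its tiling or exact numerics.

```python
import torch

def naive_chunk_fwd_o(q, k, v, h, g=None, scale=None, BT=64):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; h: [B, H, NT, K, V]; g: [B, H, T] (cumsum of log decay)
    B, H, T, K = q.shape
    scale = K ** -0.5 if scale is None else scale
    o = torch.empty_like(v)
    causal = torch.ones(BT, BT, dtype=torch.bool, device=q.device).tril()
    for i_t in range(T // BT):
        sl = slice(i_t * BT, (i_t + 1) * BT)
        b_q, b_k, b_v = q[:, :, sl].float(), k[:, :, sl].float(), v[:, :, sl].float()
        b_o = b_q @ h[:, :, i_t].float()          # inter-chunk: [BT, K] @ [K, V]
        b_A = b_q @ b_k.transpose(-1, -2)         # intra-chunk scores: [BT, BT]
        if g is not None:
            b_g = g[:, :, sl].float()
            b_o = b_o * b_g.exp()[..., None]
            # decay between positions i >= j inside the chunk (mirrors safe_exp)
            b_A = b_A * (b_g[..., :, None] - b_g[..., None, :]).clamp(max=0).exp()
        b_A = b_A.masked_fill(~causal, 0)         # causal mask within the chunk
        o[:, :, sl] = ((b_o + b_A @ b_v) * scale).to(v.dtype)
    return o
```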
498
+
499
+
500
+ def chunk_bwd_dv(
501
+ q: torch.Tensor,
502
+ k: torch.Tensor,
503
+ g: torch.Tensor,
504
+ do: torch.Tensor,
505
+ dh: torch.Tensor,
506
+ scale: float,
507
+ offsets: Optional[torch.LongTensor] = None,
508
+ indices: Optional[torch.LongTensor] = None,
509
+ head_first: bool = True,
510
+ chunk_size: int = 64
511
+ ) -> torch.Tensor:
512
+ if head_first:
513
+ B, H, T, K, V = *k.shape, do.shape[-1]
514
+ else:
515
+ B, T, H, K, V = *k.shape, do.shape[-1]
516
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
517
+ # H100 can have larger block size
518
+ if check_shared_mem('hopper', k.device.index):
519
+ CONST_TILING = 128
520
+ elif check_shared_mem():
521
+ CONST_TILING = 64
522
+ else:
523
+ CONST_TILING = 32
524
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
525
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
526
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
527
+ NV = triton.cdiv(V, BV)
528
+
529
+ dv = torch.empty_like(do)
530
+ grid = (NV, NT, B * H)
531
+ chunk_bwd_kernel_dv[grid](
532
+ q,
533
+ k,
534
+ g,
535
+ do,
536
+ dv,
537
+ dh,
538
+ offsets,
539
+ indices,
540
+ scale,
541
+ T=T,
542
+ H=H,
543
+ K=K,
544
+ V=V,
545
+ BT=BT,
546
+ BK=BK,
547
+ BV=BV,
548
+ HEAD_FIRST=head_first
549
+ )
550
+ return dv
551
+
552
+
553
+ def chunk_bwd_dv_local(
554
+ q: torch.Tensor,
555
+ k: torch.Tensor,
556
+ g: torch.Tensor,
557
+ do: torch.Tensor,
558
+ dh: torch.Tensor,
559
+ scale: float,
560
+ offsets: Optional[torch.LongTensor] = None,
561
+ indices: Optional[torch.LongTensor] = None,
562
+ head_first: bool = True,
563
+ chunk_size: int = 64
564
+ ) -> torch.Tensor:
565
+ if head_first:
566
+ B, H, T, K, V = *k.shape, do.shape[-1]
567
+ else:
568
+ B, T, H, K, V = *k.shape, do.shape[-1]
569
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
570
+ # H100 can have larger block size
571
+ if check_shared_mem('hopper', k.device.index):
572
+ CONST_TILING = 128
573
+ elif check_shared_mem():
574
+ CONST_TILING = 64
575
+ else:
576
+ CONST_TILING = 32
577
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
578
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
579
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
580
+
581
+ dv = torch.empty_like(do)
582
+ grid = (NT, B * H)
583
+ chunk_bwd_kernel_dv_local[grid](
584
+ q,
585
+ k,
586
+ g,
587
+ do,
588
+ dv,
589
+ offsets,
590
+ indices,
591
+ scale,
592
+ T=T,
593
+ H=H,
594
+ K=K,
595
+ V=V,
596
+ BT=BT,
597
+ BK=BK,
598
+ BV=BV,
599
+ HEAD_FIRST=head_first
600
+ )
601
+ return dv
602
+
603
+
604
+ def chunk_bwd_dqkwg(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ g: torch.Tensor,
609
+ do: torch.Tensor,
610
+ h: torch.Tensor,
611
+ dh: torch.Tensor,
612
+ dv: Optional[torch.Tensor] = None,
613
+ w: Optional[torch.Tensor] = None,
614
+ offsets: Optional[torch.LongTensor] = None,
615
+ indices: Optional[torch.LongTensor] = None,
616
+ chunk_size: int = 64,
617
+ scale: float = 1.0,
618
+ head_first: bool = True,
619
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
620
+
621
+ if head_first:
622
+ B, H, T, K, V = *k.shape, v.shape[-1]
623
+ else:
624
+ B, T, H, K, V = *k.shape, v.shape[-1]
625
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
626
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
627
+
628
+ CONST_TILING = 64 if check_shared_mem() else 32
629
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
630
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
631
+ NK = triton.cdiv(K, BK)
632
+ dq = torch.empty_like(q)
633
+ dk = torch.empty_like(k)
634
+ dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
635
+ dw = torch.empty_like(w) if w is not None else None
636
+
637
+ grid = (NK, NT, B * H)
638
+ chunk_bwd_kernel_dqkwg[grid](
639
+ q=q,
640
+ k=k,
641
+ v=v,
642
+ h=h,
643
+ g=g,
644
+ do=do,
645
+ dh=dh,
646
+ dv=dv,
647
+ w=w,
648
+ dw=dw,
649
+ dq=dq,
650
+ dk=dk,
651
+ dg=dg,
652
+ offsets=offsets,
653
+ indices=indices,
654
+ scale=scale,
655
+ B=B,
656
+ T=T,
657
+ H=H,
658
+ K=K,
659
+ V=V,
660
+ BT=BT,
661
+ BK=BK,
662
+ BV=BV,
663
+ HEAD_FIRST=head_first
664
+ )
665
+
666
+ if dg is not None:
667
+ dg = dg.sum(0)
668
+ return dq, dk, dw, dg
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
15
+ })
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
19
+ for BK in [32, 64, 128]
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def chunk_scaled_dot_kkt_fwd_kernel(
27
+ k,
28
+ beta,
29
+ A,
30
+ offsets,
31
+ indices,
32
+ T,
33
+ H: tl.constexpr,
34
+ K: tl.constexpr,
35
+ BT: tl.constexpr,
36
+ BK: tl.constexpr,
37
+ HEAD_FIRST: tl.constexpr,
38
+ USE_OFFSETS: tl.constexpr,
39
+ ):
40
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
41
+ i_b, i_h = i_bh // H, i_bh % H
42
+ if USE_OFFSETS:
43
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_b * T, i_b * T + T
48
+ o_t = tl.arange(0, BT)
49
+
50
+ if HEAD_FIRST:
51
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
52
+ else:
53
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
54
+ b_beta = tl.load(p_beta, boundary_check=(0,))
55
+
56
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
57
+ for i_k in range(tl.cdiv(K, BK)):
58
+ if HEAD_FIRST:
59
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
60
+ else:
61
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
62
+ b_k = tl.load(p_k, boundary_check=(0, 1))
63
+ b_kb = b_k * b_beta[:, None]
64
+ b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
65
+
66
+ b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
67
+ if HEAD_FIRST:
68
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
69
+ else:
70
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
+ def chunk_scaled_dot_kkt_fwd(
75
+ k: torch.Tensor,
76
+ beta: torch.Tensor,
77
+ cu_seqlens: Optional[torch.LongTensor],
78
+ head_first: bool = False,
79
+ chunk_size: int = 64,
80
+ output_dtype: torch.dtype = torch.float32
81
+ ) -> torch.Tensor:
82
+ r"""
83
+ Compute beta * K * K^T.
84
+
85
+ Args:
86
+ k (torch.Tensor):
87
+ The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
88
+ beta (torch.Tensor):
89
+ The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
90
+ cu_seqlens (torch.LongTensor):
91
+ The cumulative sequence lengths of the input tensor.
92
+ Default: None
93
+ head_first (bool):
94
+ If False, the input/output tensor is in the shape of `[B, T, H, K]`.
95
+ If True, the input/output tensor is in the shape of `[B, H, T, K]`.
96
+ Default: False
97
+ chunk_size (int):
98
+ The chunk size. Default: 64.
99
+ output_dtype (torch.dtype):
100
+ The dtype of the output tensor. Default: `torch.float32`
101
+
102
+ Returns:
103
+ beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
104
+ where `BT` is the chunk size.
105
+ """
106
+ if head_first:
107
+ B, H, T, K = k.shape
108
+ else:
109
+ B, T, H, K = k.shape
110
+ BT = chunk_size
111
+ indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
112
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
113
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
114
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
115
+ k=k,
116
+ beta=beta,
117
+ A=A,
118
+ offsets=cu_seqlens,
119
+ indices=indices,
120
+ T=T,
121
+ H=H,
122
+ K=K,
123
+ BT=BT,
124
+ HEAD_FIRST=head_first
125
+ )
126
+ return A
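As a sanity check, the Triton path above should agree with a naive per-chunk computation of the strictly lower-triangular `(k * beta) @ k^T` blocks. A minimal sketch, assuming the default `[B, T, H, K]` layout, no `cu_seqlens`, and `T` divisible by the chunk size:

```python
import torch

def naive_chunk_scaled_dot_kkt(k, beta, chunk_size=64):
    # k: [B, T, H, K], beta: [B, T, H]  ->  A: [B, T, H, BT]
    B, T, H, K = k.shape
    BT = chunk_size
    A = torch.zeros(B, T, H, BT, dtype=torch.float32, device=k.device)
    for i_t in range(T // BT):
        sl = slice(i_t * BT, (i_t + 1) * BT)
        b_k = k[:, sl].float()            # [B, BT, H, K]
        b_beta = beta[:, sl].float()      # [B, BT, H]
        # (k * beta) @ k^T within the chunk, per head
        b_A = torch.einsum('bthk,bshk->bhts', b_k * b_beta[..., None], b_k)
        A[:, sl] = b_A.tril(-1).permute(0, 2, 1, 3)  # keep strictly lower-triangular entries
    return A
```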
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
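For intuition, a small worked example of the two helpers above (inputs are hypothetical): with `offsets = [0, 100, 356]` and `chunk_size = 64`, the two sequences have lengths 100 and 256, i.e. 2 and 4 chunks respectively.

```python
import torch

offsets = torch.tensor([0, 100, 356])
# prepare_chunk_indices(offsets, 64) ->
# tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]])
# i.e. (sequence id, chunk id within that sequence) for every chunk.
# prepare_chunk_offsets(offsets, 64) -> tensor([0, 2, 6])
# i.e. cumulative chunk counts, analogous to `offsets` but at chunk granularity.
```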
fla/ops/delta_rule/README.md ADDED
@@ -0,0 +1,90 @@
1
+ # Chunkwise-form Parallelism of DeltaNet
2
+
3
+ This section expands on the formulation presented in Appendix B of the DeltaNet paper.[^1]
4
+
5
+ To reduce notational clutter, we focus on the first chunk, denoting $\mathbf{S}^r=\mathbf{S}_{[1]}^r$. By partially expanding the recurrence, we have:
6
+ ```math
7
+ \begin{equation}
8
+ \begin{aligned}
9
+ \mathbf{S}^r &= \underbrace{\left(\prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \right)}_{:= \mathbf{P}^r} \cdot\mathbf{S}^{0} + \overbrace{\sum_{i=1}^{r} \underbrace{\left(\prod_{j=i+1}^r \mathbf{I} - \beta^j \boldsymbol{k}^j \boldsymbol{k}^{j\top} \right)}_{:= \mathbf{P}_{i+1}^r}\beta^i \boldsymbol{k}^i\boldsymbol{v}^{i\top}}^{:=\mathbf{H}^r} \\
10
+ &=\mathbf{P}^r \cdot \mathbf{S}^{0} + \mathbf{H}^r
11
+ \end{aligned}
12
+ \end{equation}
13
+ ```
14
+
15
+ where $\mathbf{P}_i^r$ involves cumulative products of generalized Householder matrices.
16
+ We abbreviate $\mathbf{P}_1^r$ as $\mathbf{P}^r$.
17
+ This can be optimized using the classical WY representation:
18
+ ```math
19
+ \begin{equation}
20
+ \mathbf{P}^{r} = \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top} \in \mathbb{R}^{d_k \times d_k};\qquad
21
+ \boldsymbol{w}^r = \beta^r \left(\boldsymbol{k}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i \right)\boldsymbol{w}^i \right) \in \mathbb{R}^{d_k}
22
+ \end{equation}
23
+ ```
24
+
25
+ We prove this by induction:
26
+ ```math
27
+ \begin{align*}
28
+ \mathbf{P}^{r} &= \prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \\
29
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\mathbf{P}^{r-1} \\
30
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\left(\mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
31
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} + \beta^r\boldsymbol{k}^r \boldsymbol{k}^{r\top} \left(\sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
32
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \left(\boldsymbol{k}^{r} - \left(\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top} \boldsymbol{k}^i\right)\boldsymbol{w}^{i}\right) \right)^\top \\
33
+ &= \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top}
34
+ \end{align*}
35
+ ```
36
+
37
+ Similarly, $\mathbf{H}^r$ can be represented as:
38
+ ```math
39
+ \begin{equation}
40
+ \mathbf{H}^{r} = \sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} \in \mathbb{R}^{d_k \times d_v};\qquad \boldsymbol{u}^r = \beta^r \left(\boldsymbol{v}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i\right) \boldsymbol{u}^i \right)\in \mathbb{R}^{d_v}
41
+ \end{equation}
42
+ ```
43
+
44
+ This can also be proven by induction:
45
+ ```math
46
+ \begin{align*}
47
+ \mathbf{H}^{r} &= \sum_{i=1}^{r} \mathbf{P}_{i+1}^r \beta^i \boldsymbol{k}^i \boldsymbol{v}^{i\top}\\
48
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right) \mathbf{H}^{r-1} + \beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
49
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} +\beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
50
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \left(\beta^r \boldsymbol{v}^{r\top}-\beta^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top}\right) \\
51
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \beta^r\left(\boldsymbol{v}^{r}-\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top}\boldsymbol{k}^{i}\right)\boldsymbol{u}^{i} \right)^\top \\
52
+ &=\sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top}
53
+ \end{align*}
54
+ ```
55
+
56
+ In matrix form, $\mathbf{P}$ and $\mathbf{H}$ can be written as:
57
+ ```math
58
+ \begin{equation}
59
+ \mathbf{P}=\mathbf{I}-\mathbf{K}^\top\mathbf{W} \in \mathbb{R}^{d_k \times d_k}, \qquad\mathbf{H}=\mathbf{K}^\top\mathbf{U} \in \mathbb{R}^{d_k\times d_v}
60
+ \end{equation}
61
+ ```
62
+
63
+ Now we can derive the matrix form of $\mathbf{W}$ and $\mathbf{U}$:
64
+ ```math
65
+ \begin{align*}
66
+ \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} - \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\mathbf{W}\\
67
+ \left(\mathbf{I} + \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\right) \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K}
68
+ \end{align*}
69
+ ```
70
+ A similar process holds for $\mathbf{U}$. We can further write $\mathbf{W}$ and $\mathbf{U}$ in matrix form:
71
+ ```math
72
+ \begin{align*}
73
+ \mathbf{T} &= \left(\mathbf{I} + \mathrm{tril}\left(\mathrm{diag}(\beta)\mathbf{K} \mathbf{K}^\top,-1\right)\right)^{-1}\mathrm{diag}\left(\beta\right)\in \mathbb{R}^{C \times C}\\
74
+ \mathbf{W} &= \mathbf{T} \mathbf{K}\in \mathbb{R}^{C \times d_k}\\
75
+ \mathbf{U} &= \mathbf{T}\mathbf{V}\in \mathbb{R}^{C \times d_v}
76
+ \end{align*}
77
+ ```
78
+
79
+ Substituting these back into the original equations yields a hardware-efficient chunkwise algorithm for DeltaNet that leverages matrix multiplications, enabling tensor core based GPU optimization:
80
+ ```math
81
+ \begin{equation}
82
+ \begin{aligned}
83
+ \mathbf{S} &= \mathbf{P}\cdot\mathbf{S}^0 + \mathbf{H} \\
84
+ &= \mathbf{S}^0 + \mathbf{K}^\top (\mathbf{U} -\mathbf{W} \mathbf{S}^0) \in \mathbb{R}^{d_k \times d_v}\\
85
+ \mathbf{O} &= \mathbf{Q} \mathbf{S}^0 + (\mathbf{Q} \mathbf{K}^{\top} \odot \mathbf{M}) \left(\mathbf{U} - \mathbf{W} \mathbf{S}^0\right) \in \mathbb{R}^{C \times d_v}
86
+ \end{aligned}
87
+ \end{equation}
88
+ ```
89
+
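Put together, the matrix-form recurrence above maps almost line-for-line onto code. Below is a minimal, non-fused PyTorch sketch for a single head (illustrative only, assuming `T` divisible by the chunk size `C`; the actual kernels in this repo fuse these steps in Triton and tile over key/value blocks):

```python
import torch

def naive_chunk_delta_rule(q, k, v, beta, C=64):
    # q, k: [T, d_k]; v: [T, d_v]; beta: [T]
    T, d_k = k.shape
    d_v = v.shape[-1]
    S = torch.zeros(d_k, d_v)
    o = torch.empty_like(v)
    for s in range(0, T, C):
        Q, K, V = q[s:s+C], k[s:s+C], v[s:s+C]
        B = torch.diag(beta[s:s+C])
        # T = (I + tril(diag(beta) K K^T, -1))^{-1} diag(beta)
        Tmat = torch.linalg.solve(torch.eye(C) + (B @ K @ K.T).tril(-1), B)
        W, U = Tmat @ K, Tmat @ V
        M = torch.ones(C, C).tril()           # causal mask within the chunk
        U_bar = U - W @ S                     # U - W S^0
        o[s:s+C] = Q @ S + ((Q @ K.T) * M) @ U_bar
        S = S + K.T @ U_bar                   # S = S^0 + K^T (U - W S^0)
    return o, S
```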
90
+ [^1]: https://arxiv.org/abs/2406.06484
fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (34 kB). View file
 
fla/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,373 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.common.utils import prepare_chunk_indices
14
+ from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ beta: torch.Tensor,
23
+ scale: float,
24
+ initial_state: torch.Tensor,
25
+ output_final_state: bool,
26
+ offsets: Optional[torch.LongTensor] = None,
27
+ indices: Optional[torch.LongTensor] = None,
28
+ head_first: bool = True,
29
+ chunk_size: int = 64
30
+ ):
31
+ T = q.shape[2] if head_first else q.shape[1]
32
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
33
+ # obtain WY representation. u is actually the new v.
34
+ w, u, A = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ offsets=offsets,
39
+ indices=indices,
40
+ head_first=head_first,
41
+ chunk_size=BT
42
+ )
43
+
44
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
45
+ k=k,
46
+ w=w,
47
+ u=u,
48
+ g=None,
49
+ initial_state=initial_state,
50
+ output_final_state=output_final_state,
51
+ offsets=offsets,
52
+ indices=indices,
53
+ head_first=head_first,
54
+ chunk_size=BT
55
+ )
56
+ o = chunk_fwd_o(
57
+ q=q,
58
+ k=k,
59
+ v=v_new,
60
+ h=h,
61
+ g=None,
62
+ scale=scale,
63
+ offsets=offsets,
64
+ indices=indices,
65
+ head_first=head_first,
66
+ chunk_size=BT
67
+ )
68
+ return o, A, final_state
69
+
70
+
71
+ def chunk_delta_rule_bwd(
72
+ q: torch.Tensor,
73
+ k: torch.Tensor,
74
+ v: torch.Tensor,
75
+ beta: torch.Tensor,
76
+ A: torch.Tensor,
77
+ scale: float,
78
+ initial_state: torch.Tensor,
79
+ do: torch.Tensor,
80
+ dht: torch.Tensor,
81
+ offsets: Optional[torch.LongTensor] = None,
82
+ indices: Optional[torch.LongTensor] = None,
83
+ head_first: bool = True,
84
+ chunk_size: int = 64
85
+ ):
86
+ T = q.shape[2] if head_first else q.shape[1]
87
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
88
+ w, u = fwd_recompute_w_u(
89
+ k=k,
90
+ v=v,
91
+ beta=beta,
92
+ A=A,
93
+ offsets=offsets,
94
+ indices=indices,
95
+ head_first=head_first,
96
+ chunk_size=BT
97
+ )
98
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
99
+ k=k,
100
+ w=w,
101
+ u=u,
102
+ g=None,
103
+ initial_state=initial_state,
104
+ output_final_state=False,
105
+ offsets=offsets,
106
+ indices=indices,
107
+ head_first=head_first,
108
+ chunk_size=BT
109
+ )
110
+ dv = chunk_bwd_dv_local(
111
+ q=q,
112
+ k=k,
113
+ do=do,
114
+ g=None,
115
+ dh=None,
116
+ scale=scale,
117
+ offsets=offsets,
118
+ indices=indices,
119
+ head_first=head_first,
120
+ chunk_size=BT
121
+ )
122
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
123
+ q=q,
124
+ k=k,
125
+ w=w,
126
+ g=None,
127
+ h0=initial_state,
128
+ dht=dht,
129
+ do=do,
130
+ dv=dv,
131
+ scale=scale,
132
+ offsets=offsets,
133
+ indices=indices,
134
+ head_first=head_first,
135
+ chunk_size=BT
136
+ )
137
+ dq, dk, dw, _ = chunk_bwd_dqkwg(
138
+ q=q,
139
+ k=k,
140
+ v=v_new,
141
+ h=h,
142
+ w=w,
143
+ dv=dv,
144
+ do=do,
145
+ dh=dh,
146
+ g=None,
147
+ scale=scale,
148
+ offsets=offsets,
149
+ indices=indices,
150
+ head_first=head_first,
151
+ chunk_size=BT
152
+ )
153
+ dk2, dv, db = bwd_prepare_wy_repr(
154
+ k=k,
155
+ v=v,
156
+ beta=beta,
157
+ A=A,
158
+ dw=dw,
159
+ du=dv,
160
+ offsets=offsets,
161
+ indices=indices,
162
+ head_first=head_first,
163
+ chunk_size=BT
164
+ )
165
+ dk.add_(dk2)
166
+ return dq, dk, dv, db, dh0
167
+
168
+
169
+ class ChunkDeltaRuleFunction(torch.autograd.Function):
170
+
171
+ @staticmethod
172
+ @input_guard
173
+ @autocast_custom_fwd
174
+ def forward(
175
+ ctx,
176
+ q: torch.Tensor,
177
+ k: torch.Tensor,
178
+ v: torch.Tensor,
179
+ beta: torch.Tensor,
180
+ scale: float,
181
+ initial_state: torch.Tensor,
182
+ output_final_state: bool,
183
+ offsets: Optional[torch.LongTensor] = None,
184
+ head_first: bool = True,
185
+ use_qk_l2norm_in_kernel: bool = True
186
+ ):
187
+ T = q.shape[2] if head_first else q.shape[1]
188
+ chunk_size = min(64, max(triton.next_power_of_2(T), 16))
189
+
190
+ q_orig = q
191
+ k_orig = k
192
+
193
+ if use_qk_l2norm_in_kernel:
194
+ q = l2norm_fwd(q)
195
+ k = l2norm_fwd(k)
196
+
197
+ # 2-d indices denoting the offsets of chunks in each sequence
198
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
199
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
200
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
201
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
202
+
203
+ o, A, final_state = chunk_delta_rule_fwd(
204
+ q=q,
205
+ k=k,
206
+ v=v,
207
+ beta=beta,
208
+ scale=scale,
209
+ initial_state=initial_state,
210
+ output_final_state=output_final_state,
211
+ offsets=offsets,
212
+ indices=indices,
213
+ head_first=head_first,
214
+ chunk_size=chunk_size
215
+ )
216
+ ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
217
+ ctx.chunk_size = chunk_size
218
+ ctx.scale = scale
219
+ ctx.offsets = offsets
220
+ ctx.indices = indices
221
+ ctx.head_first = head_first
222
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
223
+ return o.to(q.dtype), final_state
224
+
225
+ @staticmethod
226
+ @input_guard
227
+ @autocast_custom_bwd
228
+ def backward(
229
+ ctx,
230
+ do: torch.Tensor,
231
+ dht: torch.Tensor
232
+ ):
233
+ q, k, v, beta, A, initial_state = ctx.saved_tensors
234
+ use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
235
+ if use_qk_l2norm_in_kernel:
236
+ q, q_orig = l2norm_fwd(q), q
237
+ k, k_orig = l2norm_fwd(k), k
238
+
239
+ dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
240
+ q=q,
241
+ k=k,
242
+ v=v,
243
+ beta=beta,
244
+ A=A,
245
+ scale=ctx.scale,
246
+ initial_state=initial_state,
247
+ do=do,
248
+ dht=dht,
249
+ offsets=ctx.offsets,
250
+ indices=ctx.indices,
251
+ head_first=ctx.head_first,
252
+ chunk_size=ctx.chunk_size
253
+ )
254
+ if use_qk_l2norm_in_kernel:
255
+ dq = l2norm_bwd(q_orig, dq)
256
+ dk = l2norm_bwd(k_orig, dk)
257
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None
258
+
259
+
260
+ @torch.compiler.disable
261
+ def chunk_delta_rule(
262
+ q: torch.Tensor,
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ beta: torch.Tensor,
266
+ scale: float = None,
267
+ initial_state: torch.Tensor = None,
268
+ output_final_state: bool = False,
269
+ cu_seqlens: Optional[torch.LongTensor] = None,
270
+ head_first: bool = False,
271
+ use_qk_l2norm_in_kernel: bool = False
272
+ ):
273
+ r"""
274
+ Args:
275
+ q (torch.Tensor):
276
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
277
+ k (torch.Tensor):
278
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
279
+ v (torch.Tensor):
280
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
281
+ beta (torch.Tensor):
282
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
283
+ scale (Optional[int]):
284
+ Scale factor for the RetNet attention scores.
285
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
286
+ initial_state (Optional[torch.Tensor]):
287
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
288
+ For equal-length input sequences, `N` equals the batch size `B`.
289
+ Default: `None`.
290
+ output_final_state (Optional[bool]):
291
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
292
+ cu_seqlens (torch.LongTensor):
293
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
294
+ consistent with the FlashAttention API.
295
+ head_first (Optional[bool]):
296
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
297
+ Default: `False`.
298
+ use_qk_l2norm_in_kernel (Optional[bool]):
299
+ Whether to use qk l2norm within the kernel for saving GPU memory.
300
+ Default: `False`.
301
+
302
+ Returns:
303
+ o (torch.Tensor):
304
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
305
+ final_state (torch.Tensor):
306
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
307
+
308
+ Examples::
309
+ >>> import torch
310
+ >>> import torch.nn.functional as F
311
+ >>> from einops import rearrange
312
+ >>> from fla.ops.delta_rule import chunk_delta_rule
313
+ # inputs with equal lengths
314
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
315
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
316
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
317
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
318
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
319
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
320
+ >>> o, ht = chunk_delta_rule(
321
+ q, k, v, beta,
322
+ initial_state=h0,
323
+ output_final_state=True
324
+ )
325
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
326
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
327
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
328
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
329
+ >>> o_var, ht_var = chunk_delta_rule(
330
+ q, k, v, beta,
331
+ initial_state=h0,
332
+ output_final_state=True,
333
+ cu_seqlens=cu_seqlens
334
+ )
335
+ """
336
+ assert q.dtype == k.dtype == v.dtype
337
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
338
+ assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)."
339
+
340
+ if cu_seqlens is not None:
341
+ if q.shape[0] != 1:
342
+ raise ValueError(
343
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
344
+ f"Please flatten variable-length inputs before processing."
345
+ )
346
+ if head_first:
347
+ raise RuntimeError(
348
+ "Sequences with variable lengths are not supported for head-first mode"
349
+ )
350
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
351
+ raise ValueError(
352
+ f"The number of initial states is expected to be equal to the number of input sequences, "
353
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
354
+ )
355
+ if head_first:
356
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
357
+ beta = rearrange(beta, 'b h t -> b t h')
358
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
359
+ o, final_state = ChunkDeltaRuleFunction.apply(
360
+ q,
361
+ k,
362
+ v,
363
+ beta,
364
+ scale,
365
+ initial_state,
366
+ output_final_state,
367
+ cu_seqlens,
368
+ False,
369
+ use_qk_l2norm_in_kernel
370
+ )
371
+ if head_first:
372
+ o = rearrange(o, 'b t h v -> b h t v')
373
+ return o, final_state
fla/ops/delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,607 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.jit(do_not_specialize=['T'])
21
+ def fused_recurrent_delta_rule_fwd_kernel(
22
+ q,
23
+ k,
24
+ v,
25
+ u,
26
+ beta,
27
+ o,
28
+ h0,
29
+ ht,
30
+ offsets,
31
+ scale,
32
+ T,
33
+ B: tl.constexpr,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ USE_INITIAL_STATE: tl.constexpr,
40
+ STORE_FINAL_STATE: tl.constexpr,
41
+ IS_BETA_HEADWISE: tl.constexpr,
42
+ USE_OFFSETS: tl.constexpr,
43
+ HEAD_FIRST: tl.constexpr
44
+ ):
45
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
46
+ i_n, i_h = i_nh // H, i_nh % H
47
+ if USE_OFFSETS:
48
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
49
+ all = T
50
+ T = eos - bos
51
+ else:
52
+ bos, eos = i_n * T, i_n * T + T
53
+ all = B * T
54
+
55
+ if HEAD_FIRST:
56
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
57
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
58
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
59
+ p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
60
+ if IS_BETA_HEADWISE:
61
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
62
+ else:
63
+ p_beta = beta + i_nh * T
64
+ p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
65
+ else:
66
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
67
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
68
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
69
+ p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
70
+ if IS_BETA_HEADWISE:
71
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
72
+ else:
73
+ p_beta = beta + bos * H + i_h
74
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
75
+
76
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
77
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
78
+ mask_h = mask_k[None, :] & mask_v[:, None]
79
+
80
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
81
+ if USE_INITIAL_STATE:
82
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
83
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
84
+
85
+ for _ in range(0, T):
86
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
87
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
88
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
89
+ b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
90
+ b_v -= b_v_minus
91
+ if IS_BETA_HEADWISE:
92
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
93
+ else:
94
+ b_beta = tl.load(p_beta).to(tl.float32)
95
+ tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
96
+ b_v *= b_beta
97
+ b_h += b_k[None, :] * b_v[:, None]
98
+ b_o = b_h * b_q[None, :]
99
+ b_o = tl.sum(b_o, axis=1)
100
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
101
+
102
+ p_q += K if HEAD_FIRST else H*K
103
+ p_k += K if HEAD_FIRST else H*K
104
+ p_o += V if HEAD_FIRST else H*V
105
+ p_v += V if HEAD_FIRST else H*V
106
+ p_u += V if HEAD_FIRST else H*V
107
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
116
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
117
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
118
+ })
119
+ @triton.jit(do_not_specialize=['T'])
120
+ def fused_recurrent_delta_rule_bwd_kernel(
121
+ q,
122
+ k,
123
+ v,
124
+ beta,
125
+ h0,
126
+ dh0,
127
+ dht,
128
+ do,
129
+ dq,
130
+ dk,
131
+ dv,
132
+ db,
133
+ offsets,
134
+ scale,
135
+ B: tl.constexpr,
136
+ T,
137
+ H: tl.constexpr,
138
+ K: tl.constexpr,
139
+ V: tl.constexpr,
140
+ BK: tl.constexpr,
141
+ BV: tl.constexpr,
142
+ NK: tl.constexpr,
143
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar
144
+ USE_INITIAL_STATE: tl.constexpr, # whether to use dh0
145
+ USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to use dht
146
+ USE_OFFSETS: tl.constexpr,
147
+ HEAD_FIRST: tl.constexpr
148
+ ):
149
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
150
+ i_n, i_h = i_nh // H, i_nh % H
151
+ if USE_OFFSETS:
152
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
153
+ all = T
154
+ T = eos - bos
155
+ else:
156
+ bos, eos = i_n * T, i_n * T + T
157
+ all = B * T
158
+
159
+ mask_k = i_k * BK + tl.arange(0, BK) < K
160
+ mask_v = i_v * BV + tl.arange(0, BV) < V
161
+
162
+ if HEAD_FIRST:
163
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
164
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
165
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
166
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
167
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
168
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
169
+ if IS_BETA_HEADWISE:
170
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
171
+ p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
172
+ else:
173
+ p_beta = beta + i_nh * T + T - 1
174
+ p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
175
+ else:
176
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
177
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
178
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
179
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
180
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
181
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
182
+ if IS_BETA_HEADWISE:
183
+ p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
184
+ p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
185
+ else:
186
+ p_beta = beta + (bos + T - 1) * H + i_h
187
+ p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h
188
+
189
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
190
+ if USE_FINAL_STATE_GRADIENT:
191
+ p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
192
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
193
+
194
+ for _ in range(T):
195
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
196
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
197
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
198
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
199
+ if IS_BETA_HEADWISE:
200
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
201
+ else:
202
+ b_beta = tl.load(p_beta).to(tl.float32)
203
+ b_dh += b_q[:, None] * b_do[None, :]
204
+ b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
205
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
206
+
207
+ b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
208
+ b_dv = b_dv * b_beta
209
+
210
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
211
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
212
+ if IS_BETA_HEADWISE:
213
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
214
+ else:
215
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))
216
+
217
+ b_dh -= b_k[:, None] * b_dv[None, :]
218
+
219
+ p_q -= K if HEAD_FIRST else H*K
220
+ p_k -= K if HEAD_FIRST else H*K
221
+ p_v -= V if HEAD_FIRST else H*V
222
+ p_do -= V if HEAD_FIRST else H*V
223
+ p_dk -= K if HEAD_FIRST else H*K
224
+ p_dv -= V if HEAD_FIRST else H*V
225
+ p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
226
+ p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
227
+
228
+ if USE_INITIAL_STATE:
229
+ p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
230
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
231
+
232
+ tl.debug_barrier()
233
+
234
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
235
+
236
+ if HEAD_FIRST:
237
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
238
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
239
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
240
+ if IS_BETA_HEADWISE:
241
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
242
+ else:
243
+ p_beta = beta + i_nh * T
244
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
245
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
246
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
247
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
248
+ else:
249
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
250
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
251
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
252
+ if IS_BETA_HEADWISE:
253
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
254
+ else:
255
+ p_beta = beta + bos * H + i_h
256
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
257
+ p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
258
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
260
+
261
+ if USE_INITIAL_STATE:
262
+ mask_h = mask_k[:, None] & mask_v[None, :]
263
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
264
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
265
+
266
+ for _ in range(0, T):
267
+ b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
268
+ b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
269
+ b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
270
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
271
+
272
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
273
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
274
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
275
+ if IS_BETA_HEADWISE:
276
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
277
+ else:
278
+ b_beta = tl.load(p_beta).to(tl.float32)
279
+ b_v *= b_beta
280
+
281
+ b_h += b_k[:, None] * b_v[None, :]
282
+ b_dq = b_h * b_do[None, :]
283
+ d_q = tl.sum(b_dq, axis=1) * scale
284
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
285
+
286
+ p_k += K if HEAD_FIRST else H*K
287
+ p_v += V if HEAD_FIRST else H*V
288
+ p_do += V if HEAD_FIRST else H*V
289
+ p_dq += K if HEAD_FIRST else H*K
290
+ p_dk += K if HEAD_FIRST else H*K
291
+ p_dv += V if HEAD_FIRST else H*V
292
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
293
+
294
+
295
+ def fused_recurrent_delta_rule_fwd(
296
+ q: torch.Tensor,
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ beta: torch.Tensor,
300
+ scale: float,
301
+ initial_state: torch.Tensor,
302
+ output_final_state: bool,
303
+ offsets: Optional[torch.LongTensor] = None,
304
+ head_first: bool = True
305
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
306
+ if head_first:
307
+ B, H, T, K, V = *k.shape, v.shape[-1]
308
+ else:
309
+ B, T, H, K, V = *k.shape, v.shape[-1]
310
+ N = B if offsets is None else len(offsets) - 1
311
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
312
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
313
+ assert NK == 1, "NK > 1 is not supported yet"
314
+ num_stages = 1
315
+ num_warps = 1
316
+
317
+ o = q.new_empty(NK, *v.shape)
318
+ if output_final_state:
319
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
320
+ else:
321
+ final_state = None
322
+
323
+ grid = (NV, NK, N * H)
324
+ u = torch.empty_like(v)
325
+ fused_recurrent_delta_rule_fwd_kernel[grid](
326
+ q,
327
+ k,
328
+ v,
329
+ u,
330
+ beta,
331
+ o,
332
+ initial_state,
333
+ final_state,
334
+ offsets,
335
+ scale,
336
+ T=T,
337
+ B=B,
338
+ H=H,
339
+ K=K,
340
+ V=V,
341
+ BK=BK,
342
+ BV=BV,
343
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
344
+ HEAD_FIRST=head_first,
345
+ num_warps=num_warps,
346
+ num_stages=num_stages,
347
+ )
348
+ o = o.squeeze(0)
349
+ return o, u, final_state
350
+
351
+
352
+ def fused_recurrent_delta_rule_bwd(
353
+ q: torch.Tensor,
354
+ k: torch.Tensor,
355
+ v: torch.Tensor,
356
+ beta: torch.Tensor,
357
+ dht: torch.Tensor,
358
+ do: torch.Tensor,
359
+ scale: float,
360
+ initial_state: torch.Tensor,
361
+ offsets: Optional[torch.LongTensor] = None,
362
+ head_first: bool = True
363
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
364
+ if head_first:
365
+ B, H, T, K, V = *k.shape, v.shape[-1]
366
+ else:
367
+ B, T, H, K, V = *k.shape, v.shape[-1]
368
+ N = B if offsets is None else len(offsets) - 1
369
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
370
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
371
+ assert NK == 1, "NK > 1 is not supported yet"
372
+ num_stages = 1
373
+ num_warps = 2
374
+
375
+ beta_vector = beta.ndim == v.ndim
376
+
377
+ dq = q.new_empty(NV, *q.shape)
378
+ dk = q.new_empty(NV, *k.shape)
379
+ dv = q.new_empty(NK, *v.shape)
380
+ if beta_vector:
381
+ db = q.new_empty(NV, NK, B, H, T, V) if head_first else q.new_empty(NV, NK, B, T, H, V)
382
+ else:
383
+ db = q.new_empty(NV, B, H, T) if head_first else q.new_empty(NV, B, T, H)
384
+ grid = (NV, NK, N * H)
385
+
386
+ if initial_state is not None and initial_state.requires_grad:
387
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
388
+ else:
389
+ dh0 = None
390
+
391
+ fused_recurrent_delta_rule_bwd_kernel[grid](
392
+ q,
393
+ k,
394
+ v,
395
+ beta,
396
+ initial_state,
397
+ dh0,
398
+ dht,
399
+ do,
400
+ dq,
401
+ dk,
402
+ dv,
403
+ db,
404
+ offsets,
405
+ scale,
406
+ T=T,
407
+ B=B,
408
+ H=H,
409
+ K=K,
410
+ V=V,
411
+ BK=BK,
412
+ BV=BV,
413
+ NK=NK,
414
+ IS_BETA_HEADWISE=beta_vector,
415
+ HEAD_FIRST=head_first,
416
+ num_warps=num_warps,
417
+ num_stages=num_stages
418
+ )
419
+ dq = dq.sum(0)
420
+ dk = dk.sum(0)
421
+ dv = dv.sum(0)
422
+ db = db.sum((0, 1)) if beta_vector else db.sum(0)
423
+
424
+ return dq, dk, dv, db, dh0
425
+
426
+
427
+ class FusedRecurrentFunction(torch.autograd.Function):
428
+
429
+ @staticmethod
430
+ @input_guard
431
+ def forward(
432
+ ctx,
433
+ q: torch.Tensor,
434
+ k: torch.Tensor,
435
+ v: torch.Tensor,
436
+ beta: torch.Tensor,
437
+ scale: float,
438
+ initial_state: torch.Tensor,
439
+ output_final_state: bool,
440
+ offsets: Optional[torch.LongTensor] = None,
441
+ head_first: bool = True,
442
+ use_qk_l2norm_in_kernel: bool = False
443
+ ):
444
+ q_orig = q
445
+ k_orig = k
446
+
447
+ if use_qk_l2norm_in_kernel:
448
+ q = l2norm_fwd(q)
449
+ k = l2norm_fwd(k)
450
+
451
+ o, u, final_state = fused_recurrent_delta_rule_fwd(
452
+ q=q,
453
+ k=k,
454
+ v=v,
455
+ beta=beta,
456
+ scale=scale,
457
+ initial_state=initial_state,
458
+ output_final_state=output_final_state,
459
+ offsets=offsets,
460
+ head_first=head_first
461
+ )
462
+
463
+ ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
464
+ ctx.scale = scale
465
+ ctx.offsets = offsets
466
+ ctx.head_first = head_first
467
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
468
+ return o, final_state
469
+
470
+ @staticmethod
471
+ @input_guard
472
+ def backward(ctx, do, dht):
473
+ q, k, v, beta, initial_state = ctx.saved_tensors
474
+ if ctx.use_qk_l2norm_in_kernel:
475
+ q, q_orig = l2norm_fwd(q), q
476
+ k, k_orig = l2norm_fwd(k), k
477
+ dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
478
+ q=q,
479
+ k=k,
480
+ v=v,
481
+ beta=beta,
482
+ dht=dht,
483
+ do=do,
484
+ scale=ctx.scale,
485
+ initial_state=initial_state,
486
+ offsets=ctx.offsets,
487
+ head_first=ctx.head_first
488
+ )
489
+ if ctx.use_qk_l2norm_in_kernel:
490
+ dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
491
+ return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None
492
+
493
+
494
+ @torch.compiler.disable
495
+ def fused_recurrent_delta_rule(
496
+ q: torch.Tensor,
497
+ k: torch.Tensor,
498
+ v: torch.Tensor,
499
+ beta: torch.Tensor = None,
500
+ scale: float = None,
501
+ initial_state: torch.Tensor = None,
502
+ output_final_state: bool = False,
503
+ cu_seqlens: Optional[torch.LongTensor] = None,
504
+ head_first: bool = True,
505
+ use_qk_l2norm_in_kernel: bool = False
506
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
507
+ r"""
508
+ Args:
509
+ q (torch.Tensor):
510
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
511
+ k (torch.Tensor):
512
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
513
+ v (torch.Tensor):
514
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
515
+ beta (torch.Tensor):
516
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
517
+ scale (Optional[int]):
518
+ Scale factor for the attention scores.
519
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
520
+ initial_state (Optional[torch.Tensor]):
521
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
522
+ For equal-length input sequences, `N` equals the batch size `B`.
523
+ Default: `None`.
524
+ output_final_state (Optional[bool]):
525
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
526
+ cu_seqlens (torch.LongTensor):
527
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
528
+ consistent with the FlashAttention API.
529
+ head_first (Optional[bool]):
530
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
531
+ Default: `True`.
532
+
533
+ Returns:
534
+ o (torch.Tensor):
535
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
536
+ final_state (torch.Tensor):
537
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
538
+
539
+ Examples::
540
+ >>> import torch
541
+ >>> import torch.nn.functional as F
542
+ >>> from einops import rearrange
543
+ >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
544
+ # inputs with equal lengths
545
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
546
+ >>> q = torch.randn(B, T, H, K, device='cuda')
547
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
548
+ >>> v = torch.randn(B, T, H, V, device='cuda')
549
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
550
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
551
+ >>> o, ht = fused_recurrent_delta_rule(
552
+ q, k, v, beta,
553
+ initial_state=h0,
554
+ output_final_state=True,
+ head_first=False
555
+ )
556
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
557
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
558
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
559
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
560
+ >>> o_var, ht_var = fused_recurrent_delta_rule(
561
+ q, k, v, beta,
562
+ initial_state=h0,
563
+ output_final_state=True,
564
+ cu_seqlens=cu_seqlens,
+ head_first=False
565
+ )
566
+ >>> assert o.allclose(o_var.view(o.shape))
567
+ >>> assert ht.allclose(ht_var)
568
+ """
569
+ if cu_seqlens is not None:
570
+ if q.shape[0] != 1:
571
+ raise ValueError(
572
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
573
+ f"Please flatten variable-length inputs before processing."
574
+ )
575
+ if head_first:
576
+ raise RuntimeError(
577
+ "Sequences with variable lengths are not supported for head-first mode"
578
+ )
579
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
580
+ raise ValueError(
581
+ f"The number of initial states is expected to be equal to the number of input sequences, "
582
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
583
+ )
584
+ if scale is None:
585
+ scale = k.shape[-1] ** -0.5
586
+ else:
587
+ assert scale > 0, "scale must be positive"
588
+ if beta is None:
589
+ beta = torch.ones_like(q[..., 0])
590
+ if head_first:
591
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
592
+ beta = rearrange(beta, 'b h t -> b t h')
593
+ o, final_state = FusedRecurrentFunction.apply(
594
+ q,
595
+ k,
596
+ v,
597
+ beta,
598
+ scale,
599
+ initial_state,
600
+ output_final_state,
601
+ cu_seqlens,
602
+ False,
603
+ use_qk_l2norm_in_kernel
604
+ )
605
+ if head_first:
606
+ o = rearrange(o, 'b t h v -> b h t v')
607
+ return o, final_state
fla/ops/forgetting_attn/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .parallel import parallel_forgetting_attn
4
+
5
+ __all__ = [
6
+ 'parallel_forgetting_attn'
7
+ ]
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (14.4 kB).
 
fla/ops/gated_delta_rule/chunk.py ADDED
@@ -0,0 +1,392 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
14
+ from fla.ops.utils import chunk_local_cumsum
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_gated_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ g: torch.Tensor,
23
+ beta: torch.Tensor,
24
+ scale: float,
25
+ initial_state: torch.Tensor,
26
+ output_final_state: bool,
27
+ offsets: Optional[torch.LongTensor] = None,
28
+ indices: Optional[torch.LongTensor] = None,
29
+ head_first: bool = True,
30
+ chunk_size: int = 64
31
+ ):
32
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
33
+ # obtain WY representation. u is actually the new v.
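+ # roughly speaking, w/u fold beta (and the cumulative gate g) into k/v, so the
+ # chunkwise state-passing kernel below can treat them as ordinary keys/values.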
34
+ w, u, Aw, Au = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ g=g,
39
+ offsets=offsets,
40
+ indices=indices,
41
+ head_first=head_first,
42
+ chunk_size=chunk_size
43
+ )
44
+
45
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
46
+ k=k,
47
+ w=w,
48
+ u=u,
49
+ g=g,
50
+ initial_state=initial_state,
51
+ output_final_state=output_final_state,
52
+ offsets=offsets,
53
+ indices=indices,
54
+ head_first=head_first,
55
+ chunk_size=chunk_size
56
+ )
57
+
58
+ # obtain output
59
+ o = chunk_fwd_o(
60
+ q=q,
61
+ k=k,
62
+ v=v_new,
63
+ h=h,
64
+ g=g,
65
+ scale=scale,
66
+ offsets=offsets,
67
+ indices=indices,
68
+ head_first=head_first,
69
+ chunk_size=chunk_size
70
+ )
71
+ return g, o, Aw, Au, final_state
72
+
73
+
74
+ def chunk_gated_delta_rule_bwd(
75
+ q: torch.Tensor,
76
+ k: torch.Tensor,
77
+ v: torch.Tensor,
78
+ g: torch.Tensor,
79
+ beta: torch.Tensor,
80
+ Aw: torch.Tensor,
81
+ Au: torch.Tensor,
82
+ scale: float,
83
+ initial_state: torch.Tensor,
84
+ do: torch.Tensor,
85
+ dht: torch.Tensor,
86
+ offsets: Optional[torch.LongTensor] = None,
87
+ indices: Optional[torch.LongTensor] = None,
88
+ head_first: bool = True,
89
+ chunk_size: int = 64
90
+ ):
91
+ T = q.shape[2] if head_first else q.shape[1]
92
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
93
+ w, u = fwd_recompute_w_u(
94
+ k=k,
95
+ v=v,
96
+ beta=beta,
97
+ Aw=Aw,
98
+ Au=Au,
99
+ offsets=offsets,
100
+ indices=indices,
101
+ head_first=head_first,
102
+ chunk_size=BT
103
+ )
104
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
105
+ k=k,
106
+ w=w,
107
+ u=u,
108
+ g=g,
109
+ initial_state=initial_state,
110
+ output_final_state=False,
111
+ offsets=offsets,
112
+ indices=indices,
113
+ head_first=head_first,
114
+ chunk_size=BT
115
+ )
116
+ dv = chunk_bwd_dv_local(
117
+ q=q,
118
+ k=k,
119
+ g=g,
120
+ do=do,
121
+ dh=None,
122
+ scale=scale,
123
+ offsets=offsets,
124
+ indices=indices,
125
+ head_first=head_first,
126
+ chunk_size=BT
127
+ )
128
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
129
+ q=q,
130
+ k=k,
131
+ w=w,
132
+ g=g,
133
+ h0=initial_state,
134
+ dht=dht,
135
+ do=do,
136
+ dv=dv,
137
+ scale=scale,
138
+ offsets=offsets,
139
+ indices=indices,
140
+ head_first=head_first,
141
+ chunk_size=BT
142
+ )
143
+ dq, dk, dw, dg = chunk_bwd_dqkwg(
144
+ q=q,
145
+ k=k,
146
+ v=v_new,
147
+ w=w,
148
+ g=g,
149
+ h=h,
150
+ dv=dv,
151
+ do=do,
152
+ dh=dh,
153
+ scale=scale,
154
+ offsets=offsets,
155
+ indices=indices,
156
+ head_first=head_first,
157
+ chunk_size=BT
158
+ )
159
+ dk2, dv, db, dg2 = bwd_prepare_wy_repr(
160
+ k=k,
161
+ v=v,
162
+ beta=beta,
163
+ g=g,
164
+ Aw=Aw,
165
+ Au=Au,
166
+ dw=dw,
167
+ du=dv,
168
+ offsets=offsets,
169
+ indices=indices,
170
+ head_first=head_first,
171
+ chunk_size=BT
172
+ )
173
+ dk.add_(dk2)
174
+ dg.add_(dg2)
175
+ assert dg.dtype == torch.float32, "dg should be fp32"
176
+ dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first)
177
+ return dq, dk, dv, db, dg, dh0
178
+
179
+
180
+ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
181
+
182
+ @staticmethod
183
+ @input_guard
184
+ @autocast_custom_fwd
185
+ def forward(
186
+ ctx,
187
+ q: torch.Tensor,
188
+ k: torch.Tensor,
189
+ v: torch.Tensor,
190
+ g: torch.Tensor,
191
+ beta: torch.Tensor,
192
+ scale: float,
193
+ initial_state: torch.Tensor,
194
+ output_final_state: bool,
195
+ offsets: Optional[torch.LongTensor] = None,
196
+ head_first: bool = True,
197
+ use_qk_l2norm_in_kernel: bool = False
198
+ ):
199
+ chunk_size = 64
200
+ q_orig = q
201
+ k_orig = k
202
+
203
+ if use_qk_l2norm_in_kernel:
204
+ q = l2norm_fwd(q)
205
+ k = l2norm_fwd(k)
206
+
207
+ # 2-d indices denoting the offsets of chunks in each sequence
208
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
209
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
210
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
211
+ indices = None
212
+ if offsets is not None:
213
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
214
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
215
+
216
+ g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
217
+ q=q,
218
+ k=k,
219
+ v=v,
220
+ g=g,
221
+ beta=beta,
222
+ scale=scale,
223
+ initial_state=initial_state,
224
+ output_final_state=output_final_state,
225
+ offsets=offsets,
226
+ indices=indices,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ )
230
+ ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices)
231
+ ctx.chunk_size = chunk_size
232
+ ctx.scale = scale
233
+ ctx.head_first = head_first
234
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
235
+ return o.to(q.dtype), final_state
236
+
237
+ @staticmethod
238
+ @input_guard
239
+ @autocast_custom_bwd
240
+ def backward(
241
+ ctx,
242
+ do: torch.Tensor,
243
+ dht: torch.Tensor
244
+ ):
245
+ q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors
246
+ if ctx.use_qk_l2norm_in_kernel:
247
+ q, q_orig = l2norm_fwd(q), q
248
+ k, k_orig = l2norm_fwd(k), k
249
+ dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ g=g,
254
+ beta=beta,
255
+ Aw=Aw,
256
+ Au=Au,
257
+ scale=ctx.scale,
258
+ initial_state=initial_state,
259
+ do=do,
260
+ dht=dht,
261
+ offsets=offsets,
262
+ indices=indices,
263
+ head_first=ctx.head_first,
264
+ chunk_size=ctx.chunk_size
265
+ )
266
+ if ctx.use_qk_l2norm_in_kernel:
267
+ dq = l2norm_bwd(q_orig, dq)
268
+ dk = l2norm_bwd(k_orig, dk)
269
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
270
+
271
+
272
+ @torch.compiler.disable
273
+ def chunk_gated_delta_rule(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ g: torch.Tensor,
278
+ beta: torch.Tensor,
279
+ scale: float = None,
280
+ initial_state: torch.Tensor = None,
281
+ output_final_state: bool = False,
282
+ cu_seqlens: Optional[torch.LongTensor] = None,
283
+ head_first: bool = False,
284
+ use_qk_l2norm_in_kernel: bool = False
285
+ ):
286
+ r"""
287
+ Args:
288
+ q (torch.Tensor):
289
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
290
+ k (torch.Tensor):
291
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
292
+ v (torch.Tensor):
293
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
294
+ g (torch.Tensor):
295
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
296
+ beta (torch.Tensor):
297
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
298
+ scale (Optional[int]):
299
+ Scale factor for the attention scores.
300
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
301
+ initial_state (Optional[torch.Tensor]):
302
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
303
+ For equal-length input sequences, `N` equals the batch size `B`.
304
+ Default: `None`.
305
+ output_final_state (Optional[bool]):
306
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
307
+ cu_seqlens (torch.LongTensor):
308
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
309
+ consistent with the FlashAttention API.
310
+ head_first (Optional[bool]):
311
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
312
+ Default: `False`.
313
+
314
+ Returns:
315
+ o (torch.Tensor):
316
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
317
+ final_state (torch.Tensor):
318
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319
+
320
+ Examples::
321
+ >>> import torch
322
+ >>> import torch.nn.functional as F
323
+ >>> from einops import rearrange
324
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
325
+ # inputs with equal lengths
326
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
327
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
328
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
329
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
330
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
331
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
332
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
333
+ >>> o, ht = chunk_gated_delta_rule(
334
+ q, k, v, g, beta,
335
+ initial_state=h0,
336
+ output_final_state=True,
337
+ head_first=False
338
+ )
339
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
340
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
341
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
342
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
343
+ >>> o_var, ht_var = chunk_gated_delta_rule(
344
+ q, k, v, g, beta,
345
+ initial_state=h0,
346
+ output_final_state=True,
347
+ cu_seqlens=cu_seqlens,
348
+ head_first=False
349
+ )
350
+ """
351
+ assert q.dtype == k.dtype == v.dtype
352
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
353
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False."
354
+
355
+ if cu_seqlens is not None:
356
+ if q.shape[0] != 1:
357
+ raise ValueError(
358
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
359
+ f"Please flatten variable-length inputs before processing."
360
+ )
361
+ if head_first:
362
+ raise RuntimeError(
363
+ "Sequences with variable lengths are not supported for head-first mode"
364
+ )
365
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
366
+ raise ValueError(
367
+ f"The number of initial states is expected to be equal to the number of input sequences, "
368
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
369
+ )
370
+ if head_first:
371
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
372
+ beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g))
373
+ if scale is None:
374
+ scale = k.shape[-1] ** -0.5
375
+ else:
376
+ assert scale > 0, "Scale must be positive."
377
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
378
+ q,
379
+ k,
380
+ v,
381
+ g,
382
+ beta,
383
+ scale,
384
+ initial_state,
385
+ output_final_state,
386
+ cu_seqlens,
387
+ False,
388
+ use_qk_l2norm_in_kernel
389
+ )
390
+ if head_first:
391
+ o = rearrange(o, 'b t h v -> b h t v')
392
+ return o, final_state
fla/ops/generalized_delta_rule/README.md ADDED
@@ -0,0 +1,37 @@
1
+ # Generalized Delta Rule
2
+
3
+ In delta rule we have the recurrence:
4
+
5
+ ```math
6
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T
7
+ ```
8
+
9
+ This repository implements generalized delta rule variants in which the transition matrix need not take exactly this form: the identity can be replaced, and the $\mathbf{k}_t$ appearing in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ may differ from the $\mathbf{k}_t$ used in the update term $\mathbf{v}_t\mathbf{k}_t^T$.
10
+
11
+ ## IPLR (Identity Plus Low Rank)
12
+
13
+ The first variant is IPLR, where we have:
14
+
15
+ ```math
16
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
17
+ ```
18
+
19
+ Setting $\mathbf{a}_t = -\beta_t \mathbf{k}_t$ and $\mathbf{b}_t = \mathbf{k}_t$, and using $\beta_t \mathbf{v}_t$ as the value, recovers the original delta rule. Since the transition matrix here is identity-plus-low-rank, we refer to this variant as IPLR.
20
+
21
+ ### Numerical Stability
22
+
23
+ $\mathbf{a}_t$ and $\mathbf{b}_t$ must point in opposite directions, i.e., $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ with $\lambda_t < 0$. To see why, note that the eigenvalues of $\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T$ are $1$ (with multiplicity $d-1$) and $1 + \mathbf{b}_t^T\mathbf{a}_t$; keeping the spectral radius at most $1$ requires $\mathbf{b}_t^T\mathbf{a}_t \in [-2, 0]$, and the opposite-direction constraint guarantees the non-positive part of this condition.
24
+
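+ As a reference (an illustrative sketch, not the Triton kernels in this repo), a minimal unbatched PyTorch implementation of the IPLR recurrence could look as follows; the readout $\mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$ is assumed here for illustration:
+
+ ```python
+ import torch
+
+ def iplr_recurrence(q, k, v, a, b, S=None):
+     """q, k, a, b: [T, d_k]; v: [T, d_v]; S (state): [d_v, d_k]."""
+     T, d_k = k.shape
+     d_v = v.shape[-1]
+     if S is None:
+         S = k.new_zeros(d_v, d_k)
+     I = torch.eye(d_k, dtype=k.dtype, device=k.device)
+     o = torch.empty(T, d_v, dtype=k.dtype, device=k.device)
+     for t in range(T):
+         # S_t = S_{t-1} (I + a_t b_t^T) + v_t k_t^T
+         S = S @ (I + torch.outer(a[t], b[t])) + torch.outer(v[t], k[t])
+         o[t] = S @ q[t]
+     return o, S
+
+ # with a_t = -beta_t * k_t, b_t = k_t, and beta_t * v_t fed as the value,
+ # this reduces to the classical delta rule above
+ ```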
25
+ ## DPLR (Diagonal Plus Low Rank)
26
+
27
+ The second variant is DPLR, where we have:
28
+
29
+ ```math
30
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
31
+ ```
32
+
33
+ Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7.
34
+
35
+ ## Efficient Chunkwise Implementation
36
+
37
+ For detailed information about efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing).
fla/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
2
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule',
7
+ 'chunk_iplr_delta_rule',
8
+ 'fused_recurrent_iplr_delta_rule'
9
+ ]
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.6 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (30.6 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc ADDED
Binary file (25.4 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (13.2 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (21.3 kB).
 
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # DPLR recurrence: S_t = S_{t-1} @ (diag(exp(gk_t)) + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
11
+
12
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
13
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BT', 'BK', 'BV'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def bwd_prepare_wy_repr_kernel(
30
+ A_ab_inv,
31
+ A_ak,
32
+ ag,
33
+ v,
34
+ dw,
35
+ du,
36
+ dv,
37
+ dv0,
38
+ dag,
39
+ dAak,
40
+ dAab,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if HEAD_FIRST:
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
65
+ p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+ p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
67
+ else:
68
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
69
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
70
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
72
+
73
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
74
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
75
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
76
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
77
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
78
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_v in range(tl.cdiv(V, BV)):
81
+ if HEAD_FIRST:
82
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
83
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ else:
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
89
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+ b_v = tl.load(p_v, boundary_check=(0, 1))
92
+ b_du = tl.load(p_du, boundary_check=(0, 1))
93
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
94
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
95
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
96
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
99
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
100
+ b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
101
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
102
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
103
+
104
+ for i_k in range(tl.cdiv(K, BK)):
105
+ if HEAD_FIRST:
106
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
107
+ p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
108
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
109
+ else:
110
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
111
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
112
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
113
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
114
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
115
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
116
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
117
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
118
+
119
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
120
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
121
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
122
+ # denote A = I - lower(A_ab), B = A^-1
123
+ # in the backward pass.
124
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
125
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
126
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
127
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
128
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
129
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
130
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
131
+
132
+
133
+ def chunk_dplr_bwd_wy(
134
+ A_ab_inv: torch.Tensor,
135
+ A_ak: torch.Tensor,
136
+ v: torch.Tensor,
137
+ ag: torch.Tensor,
138
+ dw: torch.Tensor,
139
+ du: torch.Tensor,
140
+ dv0: torch.Tensor,
141
+ offsets: Optional[torch.LongTensor],
142
+ indices: Optional[torch.LongTensor],
143
+ head_first: bool,
144
+ chunk_size: int,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
147
+ if head_first:
148
+ B, H, T, K, V = *dw.shape, du.shape[-1]
149
+ else:
150
+ B, T, H, K, V = *dw.shape, du.shape[-1]
151
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
152
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
153
+ BK = min(triton.next_power_of_2(K), 64)
154
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
155
+
156
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
157
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
158
+ dv = torch.empty_like(v)
159
+ dag = torch.empty_like(ag)
160
+
161
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
162
+ A_ab_inv=A_ab_inv,
163
+ A_ak=A_ak,
164
+ ag=ag,
165
+ v=v,
166
+ dw=dw,
167
+ du=du,
168
+ dv=dv,
169
+ dv0=dv0,
170
+ dag=dag,
171
+ dAak=dA_ak,
172
+ dAab=dA_ab,
173
+ offsets=offsets,
174
+ indices=indices,
175
+ T=T,
176
+ H=H,
177
+ K=K,
178
+ V=V,
179
+ BT=BT,
180
+ BK=BK,
181
+ BV=BV,
182
+ HEAD_FIRST=head_first
183
+ )
184
+ return dA_ab, dA_ak, dv, dag
fla/ops/gla/fused_recurrent.py ADDED
@@ -0,0 +1,113 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.common.fused_recurrent import fused_recurrent
9
+
10
+
11
+ def fused_recurrent_gla(
12
+ q: torch.Tensor,
13
+ k: torch.Tensor,
14
+ v: torch.Tensor,
15
+ gk: Optional[torch.Tensor] = None,
16
+ gv: Optional[torch.Tensor] = None,
17
+ scale: Optional[int] = None,
18
+ initial_state: Optional[torch.Tensor] = None,
19
+ output_final_state: bool = False,
20
+ reverse: bool = False,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = True
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""
25
+ Args:
26
+ q (torch.Tensor):
27
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ k (torch.Tensor):
29
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
30
+ v (torch.Tensor):
31
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
32
+ gk (torch.Tensor):
33
+ Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
34
+ gv (torch.Tensor):
35
+ Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values.
36
+ scale (Optional[int]):
37
+ Scale factor for the attention scores.
38
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
39
+ initial_state (Optional[torch.Tensor]):
40
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
41
+ For equal-length input sequences, `N` equals the batch size `B`.
42
+ Default: `None`.
43
+ output_final_state (Optional[bool]):
44
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
45
+ reverse (Optional[bool]):
46
+ If `True`, process the state passing in reverse order. Default: `False`.
47
+ cu_seqlens (torch.LongTensor):
48
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
49
+ consistent with the FlashAttention API.
50
+ head_first (Optional[bool]):
51
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
52
+ Default: `True`.
53
+
54
+ Returns:
55
+ o (torch.Tensor):
56
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
57
+ final_state (torch.Tensor):
58
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
59
+
60
+ Examples::
61
+ >>> import torch
62
+ >>> import torch.nn.functional as F
63
+ >>> from einops import rearrange
64
+ >>> from fla.ops.gla import fused_recurrent_gla
65
+ # inputs with equal lengths
66
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
67
+ >>> q = torch.randn(B, T, H, K, device='cuda')
68
+ >>> k = torch.randn(B, T, H, K, device='cuda')
69
+ >>> v = torch.randn(B, T, H, V, device='cuda')
70
+ >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
71
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
72
+ >>> o, ht = fused_recurrent_gla(q, k, v, g,
73
+ initial_state=h0,
74
+ output_final_state=True,
75
+ head_first=False)
76
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
77
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
78
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
79
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
80
+ >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g,
81
+ initial_state=h0,
82
+ output_final_state=True,
83
+ cu_seqlens=cu_seqlens,
84
+ head_first=False)
85
+ >>> assert o.allclose(o_var.view(o.shape))
86
+ >>> assert ht.allclose(ht_var)
87
+ """
88
+ if cu_seqlens is not None:
89
+ if q.shape[0] != 1:
90
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
91
+ f"Please flatten variable-length inputs before processing.")
92
+ if head_first:
93
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
94
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
95
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
96
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
97
+ if scale is None:
98
+ scale = k.shape[-1] ** -0.5
99
+ o, final_state = fused_recurrent(
100
+ q=q,
101
+ k=k,
102
+ v=v,
103
+ g=None,
104
+ gk=gk,
105
+ gv=gv,
106
+ scale=scale,
107
+ initial_state=initial_state,
108
+ output_final_state=output_final_state,
109
+ reverse=reverse,
110
+ cu_seqlens=cu_seqlens,
111
+ head_first=head_first
112
+ )
113
+ return o, final_state