zaydzuhri committed
Commit 652030e · verified · Parent: 5f37b1a

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. LICENSE +21 -0
  2. README.md +512 -0
  3. configs/delta_net_1B.json +29 -0
  4. configs/delta_net_340M.json +27 -0
  5. configs/gla_340M.json +24 -0
  6. configs/gla_7B.json +25 -0
  7. configs/gsa_340M.json +29 -0
  8. configs/hgrn2_340M.json +20 -0
  9. configs/rectified_transformer_120M.json +19 -0
  10. configs/scaled_softpick_transformer_120M.json +19 -0
  11. configs/scaled_softpick_transformer_340M.json +19 -0
  12. configs/scaled_vanilla_transformer_340M.json +19 -0
  13. configs/softpick_transformer_120M.json +19 -0
  14. configs/softpick_transformer_1B.json +23 -0
  15. configs/softpick_transformer_340M.json +19 -0
  16. configs/softpick_transformer_7B.json +22 -0
  17. configs/softpick_transformer_with_pruning_340M.json +63 -0
  18. configs/transformer_120M.json +18 -0
  19. configs/transformer_7B.json +21 -0
  20. configs/vanilla_transformer_1B.json +23 -0
  21. configs/vanilla_transformer_340M.json +19 -0
  22. configs/vanilla_transformer_7B.json +22 -0
  23. download_checkpoint.py +35 -0
  24. fla/layers/abc.py +218 -0
  25. fla/layers/based.py +96 -0
  26. fla/layers/delta_net.py +291 -0
  27. fla/layers/forgetting_attn.py +109 -0
  28. fla/layers/gated_deltanet.py +293 -0
  29. fla/layers/gla.py +294 -0
  30. fla/layers/gsa.py +227 -0
  31. fla/layers/rebased.py +133 -0
  32. fla/layers/rwkv7.py +221 -0
  33. fla/modules/__init__.py +29 -0
  34. fla/modules/convolution.py +434 -0
  35. fla/ops/__pycache__/__init__.cpython-312.pyc +0 -0
  36. fla/ops/based/fused_chunk.py +374 -0
  37. fla/ops/common/chunk_h_parallel.py +650 -0
  38. fla/ops/common/fused_recurrent.py +575 -0
  39. fla/ops/delta_rule/fused_chunk.py +6 -0
  40. fla/ops/gated_delta_rule/__init__.py +7 -0
  41. fla/ops/gated_delta_rule/fused_recurrent.py +321 -0
  42. fla/ops/gated_delta_rule/wy_fast.py +620 -0
  43. fla/ops/gla/__init__.py +11 -0
  44. fla/ops/gla/fused_chunk.py +631 -0
  45. fla/ops/gsa/chunk.py +1264 -0
  46. fla/ops/gsa/naive.py +68 -0
  47. fla/ops/hgrn/fused_recurrent.py +308 -0
  48. fla/ops/hgrn/naive.py +63 -0
  49. fla/ops/rwkv4/fused_recurrent.py +476 -0
  50. fla/ops/rwkv6/__init__.py +9 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,512 @@
+ <div align="center">
+
+ # 🔥 Flame: Flash Linear Attention Made Easy
+ # This is a fork for the paper:
+ # Softpick: No Attention Sink, No Massive Activations with Rectified Softmax
+
+ </div>
+
+ ## Instructions for Softpick Attention
+
+ This fork only works with older commits of torchtitan and flame, so the setup looks like this:
+
+ ```bash
+ git clone https://github.com/zaydzuhri/flame.git
+ cd flame
+ git checkout softpick-attention
+ git submodule update --init --recursive --remote
+ cd 3rdparty/torchtitan
+ git checkout 4f532e0
+ cd ../../
+
+ pip install .
+ pip install flash-attn --no-build-isolation
+ ```
+ The flash-linear-attention submodule has been changed to point to our fork: https://github.com/zaydzuhri/flash-linear-attention/tree/softpick-attention
+ There is therefore no need to clone it manually.
+
+ Then prepare the fineweb-edu `sample-100BT` subset the same way as described in the flame repo guide below.
+
+ These are the training commands used in the paper:
+ ```bash
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/vanilla_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/vanilla-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
+
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
+ ```
+
+ And the same for the extra experiments in the appendix:
+ ```bash
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/rectified.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/rectified_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/rectified-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
+
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.scaled.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_scaled_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-scaled-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
+ ```
+
+ Feel free to DM @zmkzmkz on X for any questions regarding the paper or this code!
+
+ ## Flame
+
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
+
+ **Feature Highlights:**
+
+ - 🚀 Minimal, easy-to-use, extensible training framework
+ - 🤗 Seamless integration with `fla` and `transformers`
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and support for multiple datasets
+ - 🔮 4D parallelism (coming soon)
+
+ ## Setup
+
+ To get started, clone the `flame` repository and install the required dependencies:
+
+ ```bash
+ git clone https://github.com/fla-org/flame.git
+ cd flame
+ pip install .
+ ```
+
+ `flame` keeps its dependencies minimal, bundling only `fla` and `torchtitan` as submodules.
+ After installation, initialize and update the submodules:
+ ```sh
+ git submodule update --init --recursive
+ ```
+
+ ## Dataset Preparation
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
+
+ ```py
+ from datasets import load_dataset
+
+ # load fineweb-edu with parallel processing
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
+
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
+ ```
+
+ ## Training Recipes
+
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
+
+ > [!WARNING]
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)); use streaming only when you are just testing a new corpus.
+
+ ```sh
+ bash train.sh \
+   --job.config_file flame/models/fla.toml \
+   --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
+   --model.config configs/transformer_340M.json \
+   --model.tokenizer_path fla-hub/transformer-1.3B-100B \
+   --optimizer.name AdamW \
+   --optimizer.eps 1e-15 \
+   --optimizer.lr 3e-4 \
+   --lr_scheduler.warmup_steps 1024 \
+   --lr_scheduler.lr_min 0.1 \
+   --lr_scheduler.decay_type cosine \
+   --training.batch_size 1 \
+   --training.seq_len 65536 \
+   --training.context_len 4096 \
+   --training.varlen \
+   --training.gradient_accumulation_steps 1 \
+   --training.steps 20480 \
+   --training.max_norm 1.0 \
+   --training.skip_nan_inf \
+   --training.dataset HuggingFaceFW/fineweb-edu \
+   --training.dataset_name sample-100BT \
+   --training.dataset_split train \
+   --training.streaming \
+   --training.num_workers 32 \
+   --training.prefetch_factor 2 \
+   --training.seed 42 \
+   --training.compile \
+   --checkpoint.interval 2048 \
+   --checkpoint.load_step -1 \
+   --checkpoint.keep_latest_k 2 \
+   --metrics.log_freq 1
+ ```
+
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
+ **For single-GPU debugging, set `NGPU=1`.**
+
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (`wsd`), are also supported; a sketch of switching to WSD follows the parameter list below.
+
+ **Key parameters:**
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate remains stable after the warmup period and only starts decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
+ - `--training.steps`: Total number of training steps.
+ - `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
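For example, a hedged sketch of switching the run above from the cosine schedule to WSD; the `decay_ratio` value here is purely illustrative:

```sh
# illustrative only: keep the learning rate flat after warmup, then decay over the last 10% of steps
bash train.sh \
  ... \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.decay_type wsd \
  --lr_scheduler.decay_ratio 0.1 \
  ...
```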
+
+ > [!WARNING]
+ > The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
+ > Each step processes `global_batch_size * seq_len` tokens.
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
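To make that bookkeeping concrete, here is a small sketch of the arithmetic for the example command above (`batch_size 1`, `gradient_accumulation_steps 1`, `seq_len 65536`, `NGPU=8`, `steps 20480`):

```py
# token-budget sanity check for the example run above
batch_size = 1
gradient_accumulation_steps = 1
num_gpus = 8            # NGPU
seq_len = 65536
steps = 20480           # --training.steps

global_batch_size = batch_size * gradient_accumulation_steps * num_gpus  # 8
tokens_per_step = global_batch_size * seq_len                            # 524288 (~0.5M tokens)
total_tokens = tokens_per_step * steps                                   # ~10.7B tokens, matching the "10B" run name
print(global_batch_size, tokens_per_step, total_tokens)
```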
+
+ For a detailed explanation of all parameters, run:
+
+ ```sh
+ bash train.sh -h
+ ```
+
155
+ <details>
156
+ <summary>Usage</summary>
157
+
158
+ ```py
159
+ options:
160
+ -h, --help show this help message and exit
161
+ --job.config_file JOB.CONFIG_FILE
162
+ Job config file
163
+ --job.dump_folder JOB.DUMP_FOLDER
164
+ Folder to dump job outputs
165
+ --job.description JOB.DESCRIPTION
166
+ Description of the job
167
+ --job.use_for_integration_test
168
+ Add this config to the integration test suite
169
+ --job.print_args Print the args to terminal
170
+ --model.config MODEL.CONFIG
171
+ Path to the model config
172
+ --model.norm_type MODEL.NORM_TYPE
173
+ Type of layer normalization to use [layernorm,
174
+ np_layernorm, rmsnorm, fused_rmsnorm]
175
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
176
+ Tokenizer path
177
+ --profiling.enable_profiling
178
+ Whether to enable pytorch profiler
179
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
180
+ Trace files location
181
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
182
+ How often to collect profiler traces, in iterations
183
+ --profiling.enable_memory_snapshot
184
+ Whether to dump memory snapshot
185
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
186
+ Memory snapshot files location
187
+ --optimizer.name OPTIMIZER.NAME
188
+ Optimizer to use
189
+ --optimizer.eps OPTIMIZER.EPS
190
+ Epsilon value for the optimizer.
191
+ --optimizer.fused Whether the fused implementation (CUDA only) is used.
192
+ --optimizer.scheduler {wsd,cosine,linear}
193
+ Scheduler to use. Currently supported: wsd, cosine,
194
+ and linear.
195
+ --optimizer.lr OPTIMIZER.LR
196
+ Learning rate to use
197
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
198
+ Min lr ratio for lr scheduler
199
+ --optimizer.early_step_in_backward
200
+ Whether to apply optimizer in the backward. Caution,
201
+ optimizer_in_backward is not compatible with gradients
202
+ clipping, users should not call
203
+ register_post_accumulate_grad_hook after the optimizer
204
+ is built.
205
+ --training.batch_size TRAINING.BATCH_SIZE
206
+ Batch size
207
+ --training.seq_len TRAINING.SEQ_LEN
208
+ Sequence length
209
+ --training.context_len TRAINING.CONTEXT_LEN
210
+ Max length allowed for each sequence
211
+ --training.varlen Whether to take sequences of variable length as input
212
+ --training.warmup_steps TRAINING.WARMUP_STEPS
213
+ Steps for lr scheduler warmup, normally 1/5 of
214
+ --training.steps
215
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
216
+ Number of steps to accumulate gradients before
217
+ updating parameters
218
+ --training.steps TRAINING.STEPS
219
+ How many train steps to run
220
+ --training.max_norm TRAINING.MAX_NORM
221
+ Max norm for gradient clipping
222
+ --training.skip_nan_inf
223
+ Skip batch updates when NaN or INF gradients are
224
+ encountered during training
225
+ --training.dataset TRAINING.DATASET
226
+ Dataset to use, with comma separated values
227
+ --training.dataset_name TRAINING.DATASET_NAME
228
+ The name of the dataset config, with comma separated
229
+ values if provided
230
+ --training.dataset_split TRAINING.DATASET_SPLIT
231
+ Dataset split to use, with comma separated values if
232
+ provided
233
+ --training.data_dir TRAINING.DATA_DIR
234
+ Data dirs to use, with comma separated values if
235
+ provided
236
+ --training.data_files TRAINING.DATA_FILES
237
+ Data files to use, with comma separated values if
238
+ provided
239
+ --training.data_probs TRAINING.DATA_PROBS
240
+ Data sampling probabilities, with comma separated
241
+ values if provided
242
+ --training.streaming Whether to load dataset in streaming mode, used for
243
+ huge dataset
244
+ --training.num_workers TRAINING.NUM_WORKERS
245
+ Number of subprocesses to use for data loading. 0
246
+ means that the data will be loaded in the main
247
+ process.
248
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
249
+ Number of batches loaded in advance by each worker. 2
250
+ means there will be a total of 2 * num_workers batches
251
+ prefetched across all workers.
252
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
253
+ The `data_parallel_replicate_degree` argument
254
+ specifies the degree of data parallelism for weight
255
+ replication. When this value is greater than 1,
256
+ weights will be replicated across
257
+ `data_parallel_replicate_degree` ranks. If
258
+ `data_parallel_shard_degree` is also greater than 1,
259
+ the parallelism method used is HSDP (Hybrid Sharded
260
+ Data Parallelism). Otherwise, the parallelism method
261
+ used is DDP (Distributed Data Parallelism). 1 means
262
+ disabled.
263
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
264
+ The `data_parallel_shard_degree` argument specifies
265
+ the degree of data parallelism for weight sharding.
266
+ When this value is greater than 1, weights will be
267
+ sharded across `data_parallel_shard_degree` ranks. If
268
+ `data_parallel_replicate_degree` is also greater than
269
+ 1, the parallelism method used is HSDP (Hybrid Sharded
270
+ Data Parallelism). Otherwise, the parallelism method
271
+ used is FSDP (Fully Sharded Data Parallelism). -1
272
+ means leftover ranks will be used (After
273
+ DP_REPLICATE/SP/PP). Note that only
274
+ `data_parallel_shard_degree` can be negative. 1 means
275
+ disabled.
276
+ --training.enable_cpu_offload
277
+ Whether to apply CPU offloading of parameters,
278
+ gradients, and optimizer states in FSDP
279
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
280
+ Tensor Parallelism degree. 1 means disabled.
281
+ --training.disable_loss_parallel
282
+ Whether to apply loss parallel when sequence parallel
283
+ is enabled
284
+ --training.mixed_precision_param {bfloat16,float32}
285
+ torch dtype to use for parameters when applying mixed
286
+ precision via FSDP. This feature only takes effect
287
+ when data_parallel_shard_degree > 1
288
+ --training.mixed_precision_reduce {float32}
289
+ torch dtype to use for reductions when applying mixed
290
+ precision via FSDP. This feature only takes effect
291
+ when data_parallel_shard_degree > 1
292
+ --training.compile Whether to compile the model
293
+ --training.gc_freq TRAINING.GC_FREQ
294
+ Python garbage control scheduling interval, in steps
295
+ --training.seed TRAINING.SEED
296
+ Choose the base RNG seed used for training
297
+ --training.deterministic
298
+ Use deterministic algorithms wherever possible, may be
299
+ slower
300
+ --metrics.log_freq METRICS.LOG_FREQ
301
+ How often to log metrics to TensorBoard, in iterations
302
+ --metrics.enable_tensorboard
303
+ Whether to log metrics to TensorBoard
304
+ --metrics.disable_color_printing
305
+ Whether to disable color printing in logs
306
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
307
+ Folder to dump TensorBoard states
308
+ --metrics.rank_0_only
309
+ Whether to save TensorBoard metrics only for rank 0 or
310
+ for all ranks. When pipeline_parallel_degree is > 1,
311
+ this option uses the 0th rank of the last stage
312
+ pipeline group, which is the only stage that computes
313
+ loss metrics.
314
+ --metrics.enable_wandb
315
+ Whether to log metrics to Weights & Biases
316
+ --experimental.enable_async_tensor_parallel
317
+ Whether to apply async tensor parallel (currently only
318
+ effective when compile is enabled)
319
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
320
+ Pipeline Parallelism degree, or number of ranks. 1
321
+ means disabled. If using looped schedules, this still
322
+ specifies the number of physical ranks, not the number
323
+ of stages. Stages per rank are inferred from split
324
+ points degree, and schedule.
325
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
326
+ Specify comma-separated names of modules to use as the
327
+ beginning of a split point. e.g. "layers.0,layers.2"
328
+ will cause the model to be split into 3 stages, the
329
+ first containing all the layers up to layers.0, the
330
+ second containing layers.0 and up to layers.2, the
331
+ third containing layers.2 and all the remaining
332
+ layers. Note: fully-automated splitting may be enabled
333
+ in the future, but currently the split points must be
334
+ specified manually.
335
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
336
+ Specify the Pipeline Parallel schedule to use. The
337
+ supported schedules are: https://github.com/pytorch/py
338
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
339
+ rch/distributed/pipelining/schedules.py#L2161. The
340
+ schedule must be compatible with the split points and
341
+ stages_per_rank. Looped schedules (e.g.
342
+ Interleaved1F1B) require specifying
343
+ pipeline_parallel_degree = number of ranks, and
344
+ split_points = number of stages - 1
345
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
346
+ Specify the path to the pipeline parallel schedule csv
347
+ file to use. The pipeline_parallel_schedule argument
348
+ must be either PipelineScheduleSingle,
349
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
350
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
351
+ How many microbatches to split the global training
352
+ batch into when using pipeline parallelism. The global
353
+ training batch size must be evenly divisible by the
354
+ number of microbatches. The default value will be the
355
+ number of pipeline stages, if unspecified.
356
+ --experimental.enable_compiled_autograd
357
+ Enable CompiledAutograd to compile the backward.
358
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
359
+ Context parallelism degree. 1 means disabled.
360
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
361
+ The collective to use in context parallel SDPA for kv
362
+ shards exchange. 'allgather' means to all-gather all
363
+ kv shards on ranks after the first sub-SDPA
364
+ computation, 'alltoall' means to all-to-all shuffle
365
+ the kv shards. The default value is 'allgather'.
366
+ --checkpoint.enable_checkpoint
367
+ Whether to enable checkpoint
368
+ --checkpoint.folder CHECKPOINT.FOLDER
369
+ The folder to store the checkpoints. When
370
+ enable_checkpoint is set to true, checkpoints will be
371
+ in {--job.dump_folder}/{--checkpoint.folder}.
372
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
373
+ Checkpointing interval unit of measurement ['step',
374
+ 'seconds']
375
+ --checkpoint.interval CHECKPOINT.INTERVAL
376
+ Checkpointing interval, in steps or seconds depending
377
+ on --checkpoint.interval_type
378
+ --checkpoint.model_weights_only
379
+ When model_weights_only=True, only model weights will
380
+ be saved at the end of training. With this,
381
+ checkpoints can be loaded using `torch.load(...,
382
+ weights_only=True)` after conversion. When
383
+ model_weights_only=False, the full checkpoint will be
384
+ saved. A full checkpoint includes model, optimizer and
385
+ train_state, which can be used to resume training. The
386
+ default value is false.
387
+ --checkpoint.export_dtype {float16,bfloat16,float32}
388
+ Converts to the specified precision when training
389
+ completes and model_weights_only=true. Currently
390
+ supports float32, float16, and bfloat16. The default
391
+ value is float32.
392
+ --checkpoint.create_seed_checkpoint
393
+ Initializes the full model without applying
394
+ parallelisms, and then saves it as a seed checkpoint.
395
+ Note: requires user to call train.py without
396
+ specifying any parallelisms, e.g. NGPU=1. Could be
397
+ implemented as a separate script, but this way shares
398
+ more code.
399
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
400
+ Which async checkpoint mode to use. Currently there
401
+ are 3 different modes. 1. "disabled": synchronized
402
+ checkpointing will be used. 2. "async":
403
+ torch.distributed.checkpoint.async_save will be used.
404
+ 3. "async_with_pinned_mem": this option utilizes a
405
+ dedicated pinned memory space and creates a separate
406
+ process for faster GPU->CPU transfer performance and
407
+ eliminating GIL contention. The cost is increased CPU
408
+ memory usage. If insufficient CPU memory is available,
409
+ performance may degrade due to memory paging. For most
410
+ users, "async" should suffice as the performance
411
+ overhead is typically small (on the order of tens of
412
+ seconds) compared to checkpointing frequency. This
413
+ mode can be employed to pursue near-zero checkpointing
414
+ times (e.g., < 1 second) given appropriate hardware
415
+ support such as ample CPU memory and fast PCIe.
416
+ "disabled" is the default mode.
417
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
418
+ Keeps only the latest k checkpoints, purging older
419
+ ones. If 0, keep all checkpoints. 0 is the default
420
+ value.
421
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
422
+ Load the checkpoint at the specified step. If -1, load
423
+ the latest checkpoint.
424
+ --float8.enable_float8_linear
425
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
426
+ This feature requires you to install 'torchao' which
427
+ can be found here: https://github.com/pytorch/ao
428
+ --float8.enable_fsdp_float8_all_gather
429
+ Whether to enable float8 all-gather in FSDP
430
+ --float8.precompute_float8_dynamic_scale_for_fsdp
431
+ Whether to precompute float8 scales dynamically for FSDP
432
+ --float8.scaling_type_input {dynamic,delayed}
433
+ float8 scaling for input, dynamic (default) or delayed
434
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
435
+ float8 scaling for weight, dynamic (default) or delayed
436
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
437
+ float8 scaling for grad_output, dynamic (default) or delayed
438
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
439
+ Timeout for communication operations, during
440
+ initialization and first train step.
441
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
442
+ Timeout for communication operations after the first
443
+ train step -- usually a tighter bound than during
444
+ initialization.
445
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
446
+ Flight recorder ring buffer size, >0 means recording
447
+ by default, 0 means disabled
448
+ --memory_estimation.enabled
449
+ Whether to estimate memory usage for FSDP
450
+ --memory_estimation.disable_fake_mode
451
+ Whether to estimate memory under FakeTensorMode
452
+ ```
453
+ </details>
+
+ ### Training with `torch.compile`
+
+ Since `torch 2.0`, `torch.compile` has been available to seamlessly accelerate training.
+ In `flame`, you can enable `torch.compile` simply by adding the `--training.compile` flag to your training script.
+
+ However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
+ We are actively working on resolving these issues to make compilation transparent to users.
+ In the meantime, please ensure you are using the latest dependencies.
+
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
+
+ ### Training with multiple datasets
+
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
+ `flame` makes training on multiple datasets easy.
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
+
+ ```sh
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
+ ```
+
+ ### ~Finalizing training~
+
+ > [!NOTE]
+ > Since our latest updates, this conversion is done automatically by the training script.
+
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
+ To facilitate this, we provide a straightforward conversion script:
+
+ ```sh
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
+ ```
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
+ You can then easily publish your model using `huggingface_hub` for wider accessibility.
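For example, a minimal sketch of pushing the converted folder with `huggingface_hub`; the repository id and local path below are placeholders, and we assume you are already authenticated (e.g. via `huggingface-cli login`):

```py
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-username/transformer-340M-100BT", exist_ok=True)  # placeholder repo id
api.upload_folder(
    repo_id="your-username/transformer-340M-100BT",
    folder_path="exp/your-run/checkpoint/hf",  # placeholder: wherever the converted 🤗 model was written
    commit_message="Upload flame checkpoint converted to 🤗 format",
)
```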
+
+ ### Continual training
+
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
+ This allows you to seamlessly resume training with `flame`.
+ ```sh
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
+ ```
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
+
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
+
+ ## Multi-node training
+
+ If you have access to multiple GPU nodes, consider leveraging them for optimal performance.
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
+
+ To set up multi-node training:
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes (see the sketch after this list).
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
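For the manual (non-Slurm) route, a rough sketch of what this could look like on each node; the address, port, and flags below are placeholders, and any extra rank or node-count variables your launcher expects are left to your setup:

```sh
# run on every node (all values are placeholders)
export MASTER_ADDR=10.0.0.1   # IP of the rank-0 node
export MASTER_PORT=29500
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml ...   # same flags on every node
```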
+
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.006,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "norm_first": false,
17
+ "num_heads": 8,
18
+ "num_hidden_layers": 24,
19
+ "qk_activation": "silu",
20
+ "qk_norm": "l2",
21
+ "tie_word_embeddings": false,
22
+ "use_beta": true,
23
+ "use_cache": true,
24
+ "use_gate": false,
25
+ "use_output_norm": true,
26
+ "use_short_conv": true
27
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/rectified_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "naive_rectified_attn"
19
+ }
configs/scaled_softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_softpick_attn"
19
+ }
configs/scaled_softpick_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_softpick_attn"
19
+ }
configs/scaled_vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_attn"
19
+ }
configs/softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "naive_softpick_attn"
19
+ }
configs/softpick_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "attn_impl": "parallel_softpick_attn"
23
+ }
configs/softpick_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_softpick_attn"
19
+ }
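As a quick aside, a hedged sketch of inspecting one of these bundled configs from Python; this assumes the `fla` package is installed so that its custom model types (e.g. `transformer`) are registered with `transformers`:

```py
import fla  # noqa: F401  # importing fla registers its config/model classes with transformers
from transformers import AutoConfig

# path is relative to the repository root
config = AutoConfig.from_pretrained("configs/softpick_transformer_340M.json")
print(config.model_type, config.hidden_size, config.num_heads, config.attn_impl)
```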
configs/softpick_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "attn_impl": "parallel_softpick_attn"
22
+ }
configs/softpick_transformer_with_pruning_340M.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "attention_bias": false,
3
+ "attn_impl": "parallel_softpick_attn",
4
+ "bos_token_id": 1,
5
+ "elementwise_affine": true,
6
+ "eos_token_id": 2,
7
+ "fuse_cross_entropy": true,
8
+ "fuse_norm": true,
9
+ "fuse_swiglu": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "layer_head_pruned": [
16
+ [
17
+ 2,
18
+ 1
19
+ ],
20
+ [
21
+ 2,
22
+ 7
23
+ ],
24
+ [
25
+ 2,
26
+ 12
27
+ ],
28
+ [
29
+ 2,
30
+ 13
31
+ ],
32
+ [
33
+ 3,
34
+ 5
35
+ ],
36
+ [
37
+ 3,
38
+ 13
39
+ ],
40
+ [
41
+ 3,
42
+ 14
43
+ ],
44
+ [
45
+ 13,
46
+ 6
47
+ ]
48
+ ],
49
+ "max_position_embeddings": 8192,
50
+ "model_type": "transformer_with_pruning",
51
+ "norm_eps": 1e-06,
52
+ "num_heads": 16,
53
+ "num_hidden_layers": 24,
54
+ "num_kv_heads": null,
55
+ "qk_norm": false,
56
+ "qkv_bias": false,
57
+ "rope_theta": 10000.0,
58
+ "tie_word_embeddings": false,
59
+ "transformers_version": "4.51.3",
60
+ "use_cache": true,
61
+ "vocab_size": 32000,
62
+ "window_size": null
63
+ }
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
configs/vanilla_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "attn_impl": "parallel_attn"
23
+ }
configs/vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_attn"
19
+ }
configs/vanilla_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "attn_impl": "parallel_attn"
22
+ }
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
+ import argparse
+ import os
+ from huggingface_hub import HfApi, HfFolder, snapshot_download
+
+ def main(args):
+     api = HfApi()
+     token = HfFolder.get_token()
+     experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
+     os.makedirs(
+         experiment_checkpoint_folder,
+         exist_ok=True
+     )
+
+     snapshot_download(
+         repo_id=args.repo_id,
+         token=token,
+         local_dir=experiment_checkpoint_folder,
+     )
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
+     parser.add_argument(
+         "--repo_id",
+         type=str,
+         required=True,
+         help="The repository ID on Hugging Face Hub.",
+     )
+     parser.add_argument(
+         "--experiment_checkpoint_folder",
+         type=str,
+         required=True,
+         help="The local directory to save the downloaded checkpoint.",
+     )
+     args = parser.parse_args()
+     main(args)
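A hypothetical invocation, reusing one of the checkpoint repos produced by the README's training commands (both values are illustrative; the script downloads into `<folder>/checkpoint`):

```sh
python download_checkpoint.py \
  --repo_id zaydzuhri/softpick-340M-4096-batch16-steps100000 \
  --experiment_checkpoint_folder exp/softpick.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine
```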
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
14
+ from fla.modules.activations import swiglu, swish
15
+ from fla.ops.abc.chunk import chunk_abc
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class ABCAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int = 1024,
26
+ expand_k: float = 0.5,
27
+ expand_v: float = 1.0,
28
+ num_heads: int = 4,
29
+ use_short_conv: bool = False,
30
+ conv_size: int = 4,
31
+ conv_bias: bool = False,
32
+ num_slots: Optional[int] = None,
33
+ elementwise_affine: Optional[bool] = True,
34
+ norm_eps: float = 1e-5,
35
+ gate_low_rank_dim: int = 16,
36
+ gate_logit_normalizer: int = 16,
37
+ use_rope: bool = True,
38
+ use_input_gate: bool = False,
39
+ use_output_gate: bool = True,
40
+ use_norm: bool = True,
41
+ clamp_min: Optional[float] = -32,
42
+ clamp_max: Optional[float] = 32,
43
+ layer_idx: Optional[int] = None,
44
+ **kwargs
45
+ ) -> ABCAttention:
46
+ super().__init__()
47
+
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.num_heads = num_heads
52
+ self.key_dim = int(self.hidden_size * self.expand_k)
53
+ self.value_dim = int(self.hidden_size * self.expand_v)
54
+ self.head_k_dim = self.key_dim // self.num_heads
55
+ self.head_v_dim = self.value_dim // self.num_heads
56
+
57
+ self.use_short_conv = use_short_conv
58
+ self.conv_size = conv_size
59
+ self.conv_bias = conv_bias
60
+
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.gate_logit_normalizer = gate_logit_normalizer
63
+
64
+ self.use_rope = use_rope
65
+ self.use_input_gate = use_input_gate
66
+ self.use_output_gate = use_output_gate
67
+ self.use_norm = use_norm
68
+
69
+ if num_slots is None:
70
+ num_slots = self.head_k_dim
71
+ self.num_slots = num_slots
72
+
73
+ self.norm_eps = norm_eps
74
+
75
+ self.clamp_min = clamp_min
76
+ self.clamp_max = clamp_max
77
+ self.layer_idx = layer_idx
78
+
79
+ if layer_idx is None:
80
+ warnings.warn(
81
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
82
+ "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
83
+ "when creating this class."
84
+ )
85
+
86
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
87
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
88
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
89
+
90
+ if use_output_gate:
91
+ self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
92
+ self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
93
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
94
+
95
+ if use_short_conv:
96
+ self.conv_size = conv_size
97
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
98
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
99
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
100
+
101
+ if self.use_norm:
102
+ if self.use_output_gate:
103
+ self.g_norm = FusedRMSNormGated(
104
+ hidden_size=self.head_v_dim,
105
+ elementwise_affine=elementwise_affine,
106
+ eps=norm_eps
107
+ )
108
+ else:
109
+ self.g_norm = RMSNorm(
110
+ hidden_size=self.head_v_dim,
111
+ elementwise_affine=elementwise_affine,
112
+ eps=norm_eps
113
+ )
114
+
115
+ if self.use_rope:
116
+ self.rotary = RotaryEmbedding(self.head_k_dim)
117
+
118
+ def forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[torch.Tensor] = None,
122
+ past_key_values: Optional[Cache] = None,
123
+ use_cache: Optional[bool] = False,
124
+ output_attentions: Optional[bool] = False,
125
+ **kwargs
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
127
+ if attention_mask is not None:
128
+ assert len(attention_mask.shape) == 2, (
129
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
130
+ "for padding purposes (0 indicating padding). "
131
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
132
+ )
133
+
134
+ last_state = None
135
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
136
+ last_state = past_key_values[self.layer_idx]
137
+
138
+ cu_seqlens = kwargs.get('cu_seqlens', None)
139
+ if cu_seqlens is not None:
140
+ raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
141
+ if self.use_short_conv:
142
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
143
+ if last_state is not None:
144
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
145
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
146
+ q, conv_state_q = self.q_conv1d(
147
+ x=self.q_proj(hidden_states),
148
+ mask=conv_mask,
149
+ cache=conv_state_q,
150
+ output_final_state=use_cache,
151
+ cu_seqlens=cu_seqlens
152
+ )
153
+ k, conv_state_k = self.k_conv1d(
154
+ x=self.k_proj(hidden_states),
155
+ mask=conv_mask,
156
+ cache=conv_state_k,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens
159
+ )
160
+ v, conv_state_v = self.v_conv1d(
161
+ x=self.v_proj(hidden_states),
162
+ mask=conv_mask,
163
+ cache=conv_state_v,
164
+ output_final_state=use_cache,
165
+ cu_seqlens=cu_seqlens
166
+ )
167
+ else:
168
+ q = self.q_proj(hidden_states)
169
+ k = self.k_proj(hidden_states)
170
+ v = self.v_proj(hidden_states)
171
+
172
+ if self.use_input_gate:
173
+ q, k, v = map(lambda x: swish(x), (q, k, v))
174
+ # dealing with left-padding
175
+ if attention_mask is not None:
176
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
177
+
178
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
179
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
180
+ if self.use_rope:
181
+ seqlen_offset = 0
182
+ if past_key_values is not None:
183
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
184
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)
185
+
186
+ s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
187
+ s = s.clamp_(self.clamp_min, self.clamp_max)
188
+
189
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
190
+ o, recurrent_state = chunk_abc(
191
+ q=q,
192
+ k=k,
193
+ v=v,
194
+ s=s,
195
+ initial_state=recurrent_state,
196
+ output_final_state=use_cache,
197
+ head_first=False
198
+ )
199
+ if past_key_values is not None:
200
+ past_key_values.update(
201
+ recurrent_state=recurrent_state,
202
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
203
+ layer_idx=self.layer_idx,
204
+ offset=q.shape[1]
205
+ )
206
+
207
+ if self.use_norm and not self.use_output_gate:
208
+ o = self.g_norm(o)
209
+ elif self.use_output_gate:
210
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
211
+ o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
212
+ o = rearrange(o, '... h d -> ... (h d)')
213
+ o = self.o_proj(o)
214
+
215
+ return o, None, past_key_values
216
+
217
+ def state_size(self, seq_len: int = 2048):
218
+ return 2 * self.num_slots * self.hidden_size
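A minimal usage sketch for the `ABCAttention` layer that ends here. The constructor arguments shown below (`hidden_size`, `num_heads`, `layer_idx`) are assumptions made to mirror the sibling layers in this commit (the full signature sits earlier in `abc.py`), and a CUDA device with bfloat16 weights is assumed because `chunk_abc` is a Triton kernel:
import torch
from fla.layers.abc import ABCAttention

# Assumed constructor arguments; the full signature is defined earlier in abc.py.
layer = ABCAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, cache = layer(x)        # forward returns (output, None, past_key_values)
assert o.shape == x.shape     # o_proj maps back to hidden_size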
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Linear attention in Based.
6
+ https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules.feature_map import TaylorFeatureMap
14
+ from fla.ops.based import parallel_based
15
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
16
+
17
+
18
+ class BasedLinearAttention(nn.Module):
19
+
20
+ def __init__(
21
+ self,
22
+ hidden_size: int,
23
+ feature_dim: int = 16,
24
+ num_key_value_heads: int = 12,
25
+ num_heads: int = 12,
26
+ feature_name: str = "taylor_exp",
27
+ eps: float = 1e-12,
28
+ causal: bool = True,
29
+ mode: str = "parallel",
30
+ ):
31
+ super().__init__()
32
+
33
+ self.hidden_size = hidden_size
34
+ self.mode = mode
35
+ self.feature_name = feature_name
36
+ self.feature_dim = feature_dim
37
+ self.num_key_value_heads = num_key_value_heads
38
+ self.num_heads = num_heads
39
+ self.head_dim = self.hidden_size // self.num_key_value_heads
40
+ assert self.hidden_size % self.head_dim == 0
41
+ self.causal = causal
42
+
43
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
44
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
45
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
46
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
47
+ self.dropout = nn.Identity()
48
+ self.feature_map = TaylorFeatureMap(feature_dim)
49
+ self.eps = eps
50
+
51
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
52
+ mode = self.mode
53
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
54
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
55
+ if mode == "fused_chunk":
56
+ q, k = self.feature_map(q), self.feature_map(k)
57
+ o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
58
+ elif mode == 'chunk':
59
+ q, k = self.feature_map(q), self.feature_map(k)
60
+ o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
61
+ elif mode == 'parallel':
62
+ assert q.shape[-1] <= 128
63
+ o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
64
+ o = rearrange(o, 'b t h d -> b t (h d)')
65
+ o = self.o_proj(o)
66
+ o = self.dropout(o)
67
+ return o
68
+
69
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
70
+
71
+ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
72
+ """
73
+ x (torch.Tensor): tensor of shape (b, d, t)
74
+ y (torch.Tensor): tensor of shape (b, d, t)
75
+ """
76
+ # hidden_states = hidden_states.transpose(1, 2)
77
+ b, t, _ = hidden_states.size()
78
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
79
+
80
+ q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
81
+ k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
82
+ v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)
83
+
84
+ # Linear attention
85
+ q, k = self.feature_map(q), self.feature_map(k)
86
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
87
+
88
+ # Compute attention
89
+ if self.causal:
90
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
91
+ else:
92
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
93
+ y = rearrange(y, 'b h t d -> b t (h d)')
94
+ y = self.o_proj(y.to(hidden_states.dtype))
95
+ y = self.dropout(y)
96
+ return y.to(hidden_states.dtype)
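A short sketch of the pure-PyTorch `forward_reference` path above, which needs no Triton kernels and therefore runs on CPU; with the defaults, `head_dim = hidden_size // num_key_value_heads = 64` and queries/keys use `feature_dim = 16` per head. The shapes below are illustrative assumptions, not values from this commit:
import torch
from fla.layers.based import BasedLinearAttention

layer = BasedLinearAttention(hidden_size=768)   # defaults: 12 heads, head_dim 64, feature_dim 16
x = torch.randn(2, 32, 768)
y = layer.forward_reference(x)                  # quadratic-time reference path from the zoology repo
assert y.shape == x.shape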
fla/layers/delta_net.py ADDED
@@ -0,0 +1,291 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.processing_utils import Unpack
18
+
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ def elu_p1(x):
23
+ return (F.elu(x, 1., False) + 1.).to(x)
24
+
25
+
26
+ def sum_norm(x):
27
+ return (x / x.sum(-1, keepdim=True)).to(x)
28
+
29
+
30
+ class DeltaNet(nn.Module):
31
+ r"""
32
+ The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa:
33
+ DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa
34
+
35
+ Args:
36
+ mode (str, Optional):
37
+ Which DeltaNet kernel to use.
38
+ Currently available: `chunk` and `fused_recurrent` (`fused_chunk` is deprecated).
39
+ Default: `chunk`.
40
+ hidden_size (int, Optional):
41
+ The hidden size of the input. Default: 1024.
42
+ expand_k (float, Optional):
43
+ The expansion ratio for the key dim. Default: 1.0.
44
+ expand_v (float, Optional):
45
+ The expansion ratio for the value dim. Default: 1.0.
46
+ num_heads (int, Optional):
47
+ The number of heads. Default: 4.
48
+ use_beta (bool, Optional):
49
+ Whether to use beta. Default: `True`.
50
+ use_gate (bool, Optional):
51
+ Whether to use output gate. Default: `False`.
52
+ use_short_conv (bool, Optional):
53
+ Whether to use short convolutions. Default: `True`.
54
+ conv_size (int, Optional):
55
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
56
+ conv_bias (bool, Optional):
57
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
58
+ allow_neg_eigval (bool, Optional):
59
+ Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
60
+ See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
61
+ layer_idx (int, Optional):
62
+ The index of the layer. Default: None.
63
+ norm_eps (float, Optional):
64
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
65
+ qk_activation (str, Optional):
66
+ The activation function for the query and key. Default: `silu`.
67
+ qk_norm (str, Optional):
68
+ The normalization method for the query and key. Default: `l2`.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ d_model: int = None,
75
+ hidden_size: int = 1024,
76
+ expand_k: float = 1.0,
77
+ expand_v: float = 1.0,
78
+ num_heads: int = 4,
79
+ use_beta: bool = True,
80
+ use_gate: bool = False,
81
+ use_short_conv: bool = True,
82
+ conv_size: int = 4,
83
+ conv_bias: bool = False,
84
+ allow_neg_eigval: bool = False,
85
+ layer_idx: int = None,
86
+ qk_activation: str = 'silu',
87
+ qk_norm: str = 'l2',
88
+ norm_eps: float = 1e-5,
89
+ **kwargs
90
+ ) -> DeltaNet:
91
+ super().__init__()
92
+
93
+ self.mode = mode
94
+ self.qk_activation = qk_activation
95
+ self.qk_norm = qk_norm
96
+
97
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
98
+ assert self.qk_norm in ['l2', 'sum']
99
+
100
+ if d_model is not None:
101
+ hidden_size = d_model
102
+ self.hidden_size = hidden_size
103
+ self.expand_k = expand_k
104
+ self.expand_v = expand_v
105
+ self.num_heads = num_heads
106
+ self.use_gate = use_gate
107
+ self.use_short_conv = use_short_conv
108
+ self.conv_size = conv_size
109
+ self.conv_bias = conv_bias
110
+ self.allow_neg_eigval = allow_neg_eigval
111
+
112
+ self.key_dim = int(hidden_size * expand_k)
113
+ self.value_dim = int(hidden_size * expand_v)
114
+ self.head_k_dim = self.key_dim // num_heads
115
+ self.head_v_dim = self.value_dim // num_heads
116
+ self.layer_idx = layer_idx
117
+
118
+ self.silu = nn.SiLU()
119
+ if mode == 'fused_chunk':
120
+ raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
121
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
122
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
123
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
124
+
125
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
126
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
127
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ self.use_beta = use_beta
130
+ if self.use_beta:
131
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
132
+ if use_short_conv:
133
+ self.conv_size = conv_size
134
+ self.q_conv1d = ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation='silu' if qk_activation == 'silu' else None
138
+ )
139
+ self.k_conv1d = ShortConvolution(
140
+ hidden_size=self.key_dim,
141
+ kernel_size=conv_size,
142
+ activation='silu' if qk_activation == 'silu' else None
143
+ )
144
+ self.v_conv1d = ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation='silu'
148
+ )
149
+ else:
150
+ raise UserWarning(
151
+ "ShortConvolution is crucial to the performance. "
152
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
153
+ )
154
+ if use_gate:
155
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
156
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
157
+ else:
158
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
159
+
160
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ attention_mask: Optional[torch.Tensor] = None,
166
+ past_key_values: Optional[Cache] = None,
167
+ use_cache: Optional[bool] = False,
168
+ output_attentions: Optional[bool] = False,
169
+ **kwargs: Unpack[Dict]
170
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
171
+ if attention_mask is not None:
172
+ assert len(attention_mask.shape) == 2, (
173
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
174
+ "for padding purposes (0 indicating padding). "
175
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
176
+ )
177
+
178
+ # change to inference mode.
179
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
180
+
181
+ last_state = None
182
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
183
+ last_state = past_key_values[self.layer_idx]
184
+
185
+ cu_seqlens = kwargs.get('cu_seqlens', None)
186
+ if self.use_short_conv:
187
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
188
+ if last_state is not None:
189
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
190
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
191
+ q, conv_state_q = self.q_conv1d(
192
+ x=self.q_proj(hidden_states),
193
+ mask=conv_mask,
194
+ cache=conv_state_q,
195
+ output_final_state=use_cache,
196
+ cu_seqlens=cu_seqlens
197
+ )
198
+ k, conv_state_k = self.k_conv1d(
199
+ x=self.k_proj(hidden_states),
200
+ mask=conv_mask,
201
+ cache=conv_state_k,
202
+ output_final_state=use_cache,
203
+ cu_seqlens=cu_seqlens
204
+ )
205
+ v, conv_state_v = self.v_conv1d(
206
+ x=self.v_proj(hidden_states),
207
+ mask=conv_mask,
208
+ cache=conv_state_v,
209
+ output_final_state=use_cache,
210
+ cu_seqlens=cu_seqlens
211
+ )
212
+ else:
213
+ q = self.q_proj(hidden_states)
214
+ k = self.k_proj(hidden_states)
215
+ if self.qk_activation == 'silu':
216
+ q, k = self.silu(q), self.silu(k)
217
+ v = self.silu(self.v_proj(hidden_states))
218
+
219
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
220
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
221
+ if self.qk_activation != 'silu':
222
+ if self.qk_activation == 'relu':
223
+ q, k = q.relu(), k.relu()
224
+ elif self.qk_activation == 'elu':
225
+ q, k = elu_p1(q), elu_p1(k)
226
+ elif self.qk_activation == 'identity':
227
+ pass
228
+ else:
229
+ raise NotImplementedError
230
+
231
+ if self.qk_norm == 'sum':
232
+ q = sum_norm(q).to(q)
233
+ k = sum_norm(k).to(k)
234
+
235
+ if self.use_beta:
236
+ beta = self.b_proj(hidden_states).sigmoid()
237
+ else:
238
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
239
+
240
+ if self.allow_neg_eigval:
241
+ beta = beta * 2.
242
+
243
+ # dealing with padding
244
+ if attention_mask is not None:
245
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
246
+
247
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
248
+ if mode == 'fused_recurrent':
249
+ o, recurrent_state = fused_recurrent_delta_rule(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ beta=beta,
254
+ initial_state=recurrent_state,
255
+ output_final_state=use_cache,
256
+ cu_seqlens=cu_seqlens,
257
+ head_first=False,
258
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
259
+ )
260
+ elif mode == 'chunk':
261
+ o, recurrent_state = chunk_delta_rule(
262
+ q=q,
263
+ k=k,
264
+ v=v,
265
+ beta=beta,
266
+ initial_state=recurrent_state,
267
+ output_final_state=use_cache,
268
+ cu_seqlens=cu_seqlens,
269
+ head_first=False,
270
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
271
+ )
272
+ else:
273
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
274
+
275
+ if past_key_values is not None:
276
+ past_key_values.update(
277
+ recurrent_state=recurrent_state,
278
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
279
+ layer_idx=self.layer_idx,
280
+ offset=q.shape[1]
281
+ )
282
+
283
+ if self.use_gate:
284
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
285
+ o = self.o_norm(o, g)
286
+ else:
287
+ o = self.o_norm(o)
288
+ o = rearrange(o, 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
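A minimal forward-pass sketch for the `DeltaNet` layer above, assuming the `fla` package from this commit is importable and a CUDA device is available (the short convolutions and `chunk_delta_rule` are Triton kernels); a sequence longer than 64 tokens keeps the `chunk` path selected:
import torch
from fla.layers.delta_net import DeltaNet

layer = DeltaNet(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)            # (output, None, past_key_values); no cache since use_cache=False
assert o.shape == x.shape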
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from transformers.utils import logging
14
+
15
+ from fla.modules import GroupNorm
16
+ from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class ForgettingAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ hidden_size: int = 2048,
30
+ num_heads: int = 32,
31
+ num_kv_heads: Optional[int] = None,
32
+ qkv_bias: bool = False,
33
+ qk_norm: bool = False,
34
+ window_size: Optional[int] = None,
35
+ use_output_gate: bool = False,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = self.hidden_size // self.num_heads
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+ self.qk_norm = qk_norm
51
+
52
+ self.window_size = window_size
53
+ self.use_output_gate = use_output_gate
54
+ self.layer_idx = layer_idx
55
+
56
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
57
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
58
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
59
+ self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
60
+
61
+ if use_output_gate:
62
+ self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
63
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
64
+
65
+ if qk_norm:
66
+ self.q_norm = GroupNorm(
67
+ num_groups=self.num_heads,
68
+ hidden_size=self.hidden_size,
69
+ is_rms_norm=True,
70
+ )
71
+ self.k_norm = GroupNorm(
72
+ num_groups=self.num_kv_heads,
73
+ hidden_size=self.kv_dim,
74
+ is_rms_norm=True,
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.LongTensor] = None,
81
+ past_key_values: Optional[Cache] = None,
82
+ output_attentions: bool = False,
83
+ use_cache: bool = False,
84
+ **kwargs,
85
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86
+ if attention_mask is not None:
87
+ assert len(attention_mask.shape) == 2, (
88
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
89
+ "for padding purposes (0 indicating padding). "
90
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
91
+ )
92
+
93
+ cu_seqlens = kwargs.get('cu_seqlens', None)
94
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
95
+ f = F.logsigmoid(self.f_proj(hidden_states).float())
96
+ if self.qk_norm:
97
+ q, k = self.q_norm(q), self.k_norm(k)
98
+
99
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
100
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
101
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
102
+
103
+ o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
104
+ o = rearrange(o, '... h d -> ... (h d)')
105
+ if self.use_output_gate:
106
+ o = self.g_proj(hidden_states).sigmoid() * o
107
+ o = self.o_proj(o)
108
+
109
+ return o, None, past_key_values
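A minimal forward-pass sketch for `ForgettingAttention`, assuming a CUDA device for the Triton-based `parallel_forgetting_attn` kernel; `hidden_size` must be divisible by `num_heads` so that `head_dim` is an integer. The sizes below are illustrative assumptions:
import torch
from fla.layers.forgetting_attn import ForgettingAttention

layer = ForgettingAttention(hidden_size=512, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 256, 512, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)            # forget gates come from f_proj + logsigmoid inside forward
assert o.shape == x.shape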
fla/layers/gated_deltanet.py ADDED
@@ -0,0 +1,293 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.processing_utils import Unpack
19
+
20
+ from fla.models.utils import Cache
21
+
22
+
23
+ @torch.compile
24
+ def elu_p1(x):
25
+ return (F.elu(x, 1., False) + 1.).to(x)
26
+
27
+
28
+ @torch.compile
29
+ def sum_norm(x):
30
+ return (x / x.sum(-1, keepdim=True)).to(x)
31
+
32
+
33
+ class GatedDeltaNet(nn.Module):
34
+ """
35
+ The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
36
+
37
+ Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
38
+
39
+ Parameter allocation when use_gate=True:
40
+ - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
41
+ - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
42
+ - Others are negligibly small.
43
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
44
+ NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
45
+
46
+ Parameter allocation when use_gate=False:
47
+ - 1 * hidden_size * hidden_size for the q_proj and k_proj each
48
+ - 2 * hidden_size * hidden_size for the v_proj and o_proj each
49
+ - Others are negligibly small.
50
+ - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
51
+
52
+ Args:
53
+ hidden_size (int, Optional):
54
+ The hidden size of the input. Default: 2048.
55
+ expand_v (float, Optional):
56
+ The expansion ratio for the value dim. Default: 2.0.
57
+ head_dim (int, Optional):
58
+ The dimension of each head. Default: 256.
59
+ num_heads (int, Optional):
60
+ The number of heads. Default: 6.
61
+ mode (str, Optional):
62
+ Which Gated DeltaNet kernel to use.
63
+ Currently available: `chunk` and `fused_recurrent`.
64
+ Default: `chunk`.
65
+ use_beta (bool, Optional):
66
+ Whether to use beta. Default: `True`.
67
+ use_gate (bool, Optional):
68
+ Whether to use output gate. Default: `True`.
69
+ use_short_conv (bool, Optional):
70
+ Whether to use short convolutions. Default: `True`.
71
+ conv_size (int, Optional):
72
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
73
+ conv_bias (bool, Optional):
74
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
75
+ layer_idx (int, Optional):
76
+ The index of the layer. Default: None.
77
+ norm_eps (float, Optional):
78
+ The epsilon value for the normalization layer. Default: 1e-5.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ hidden_size: int = 2048,
84
+ expand_v: float = 2,
85
+ head_dim: int = 256,
86
+ num_heads: int = 6,
87
+ mode: str = 'chunk',
88
+ use_gate: bool = True,
89
+ use_short_conv: bool = True,
90
+ conv_size: int = 4,
91
+ conv_bias: bool = False,
92
+ layer_idx: int = None,
93
+ norm_eps: float = 1e-5,
94
+ **kwargs
95
+ ) -> GatedDeltaNet:
96
+ super().__init__()
97
+
98
+ self.mode = mode
99
+
100
+ self.hidden_size = hidden_size
101
+ self.expand_v = expand_v
102
+
103
+ self.use_gate = use_gate
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+
108
+ self.head_dim = head_dim
109
+ self.num_heads = num_heads
110
+
111
+ self.key_dim = int(self.num_heads * self.head_dim)
112
+ self.value_dim = int(self.key_dim * self.expand_v)
113
+ self.head_k_dim = head_dim
114
+ self.head_v_dim = int(head_dim * self.expand_v)
115
+ self.layer_idx = layer_idx
116
+
117
+ # Consistency check: Ensure expand_v produces integer values
118
+ if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
119
+ raise ValueError(
120
+ f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
121
+ f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
122
+ )
123
+ if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
124
+ raise ValueError(
125
+ f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
126
+ f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
127
+ )
128
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
129
+
130
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
131
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
132
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
133
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
134
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
135
+
136
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
137
+ self.A_log = nn.Parameter(torch.log(A))
138
+ self.A_log._no_weight_decay = True
139
+ # hard coded for now
140
+ dt_min = 0.001
141
+ dt_max = 0.1
142
+ dt_init_floor = 1e-4
143
+ dt = torch.exp(
144
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
145
+ + math.log(dt_min)
146
+ )
147
+ dt = torch.clamp(dt, min=dt_init_floor)
148
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
149
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
150
+ self.dt_bias = nn.Parameter(inv_dt)
151
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
152
+ # name.endswith("bias") in param_grouping.py
153
+ self.dt_bias._no_weight_decay = True
154
+
155
+ if use_short_conv:
156
+ self.conv_size = conv_size
157
+ self.q_conv1d = ShortConvolution(
158
+ hidden_size=self.key_dim,
159
+ kernel_size=conv_size,
160
+ activation='silu'
161
+ )
162
+ self.k_conv1d = ShortConvolution(
163
+ hidden_size=self.key_dim,
164
+ kernel_size=conv_size,
165
+ activation='silu'
166
+ )
167
+ self.v_conv1d = ShortConvolution(
168
+ hidden_size=self.value_dim,
169
+ kernel_size=conv_size,
170
+ activation='silu'
171
+ )
172
+ else:
173
+ raise UserWarning(
174
+ "ShortConvolution is crucial to the performance. "
175
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
176
+ )
177
+ if use_gate:
178
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
179
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
180
+ else:
181
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
182
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
183
+
184
+ def forward(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ past_key_values: Optional[Cache] = None,
189
+ use_cache: Optional[bool] = False,
190
+ output_attentions: Optional[bool] = False,
191
+ **kwargs: Unpack[Dict]
192
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
193
+ if attention_mask is not None:
194
+ assert len(attention_mask.shape) == 2, (
195
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
196
+ "for padding purposes (0 indicating padding). "
197
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
198
+ )
199
+
200
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
201
+ if self.training:
202
+ assert mode == 'chunk', "Only chunk mode is supported in training."
203
+
204
+ last_state = None
205
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
206
+ last_state = past_key_values[self.layer_idx]
207
+
208
+ cu_seqlens = kwargs.get('cu_seqlens', None)
209
+ if self.use_short_conv:
210
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
211
+ if last_state is not None:
212
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
213
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
214
+ q, conv_state_q = self.q_conv1d(
215
+ x=self.q_proj(hidden_states),
216
+ mask=conv_mask,
217
+ cache=conv_state_q,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens
220
+ )
221
+ k, conv_state_k = self.k_conv1d(
222
+ x=self.k_proj(hidden_states),
223
+ mask=conv_mask,
224
+ cache=conv_state_k,
225
+ output_final_state=use_cache,
226
+ cu_seqlens=cu_seqlens
227
+ )
228
+ v, conv_state_v = self.v_conv1d(
229
+ x=self.v_proj(hidden_states),
230
+ mask=conv_mask,
231
+ cache=conv_state_v,
232
+ output_final_state=use_cache,
233
+ cu_seqlens=cu_seqlens
234
+ )
235
+ else:
236
+ q = F.silu(self.q_proj(hidden_states))
237
+ k = F.silu(self.k_proj(hidden_states))
238
+ v = F.silu(self.v_proj(hidden_states))
239
+
240
+ q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
241
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
242
+ beta = self.b_proj(hidden_states).sigmoid()
243
+ g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
244
+
245
+ # dealing with padding
246
+ if attention_mask is not None:
247
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
248
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
249
+
250
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
251
+ if mode == 'chunk':
252
+ o, recurrent_state = chunk_gated_delta_rule(
253
+ q=q,
254
+ k=k,
255
+ v=v,
256
+ g=g,
257
+ beta=beta,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False,
262
+ use_qk_l2norm_in_kernel=True
263
+ )
264
+ elif mode == 'fused_recurrent':
265
+ o, recurrent_state = fused_recurrent_gated_delta_rule(
266
+ q=q,
267
+ k=k,
268
+ v=v,
269
+ g=g,
270
+ beta=beta,
271
+ initial_state=recurrent_state,
272
+ output_final_state=use_cache,
273
+ cu_seqlens=cu_seqlens,
274
+ head_first=False,
275
+ use_qk_l2norm_in_kernel=True
276
+ )
277
+ if past_key_values is not None:
278
+ past_key_values.update(
279
+ recurrent_state=recurrent_state,
280
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
281
+ layer_idx=self.layer_idx,
282
+ offset=q.shape[1]
283
+ )
284
+
285
+ if self.use_gate:
286
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
287
+ o = self.o_norm(o, g)
288
+ else:
289
+ o = self.o_norm(o)
290
+ o = rearrange(o, 'b t h d -> b t (h d)')
291
+ o = self.o_proj(o)
292
+
293
+ return o, None, past_key_values
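A minimal forward-pass sketch for `GatedDeltaNet`, keeping `num_heads * head_dim = 0.75 * hidden_size` as the docstring above recommends; a CUDA device and bfloat16 weights are assumed for the Triton `chunk_gated_delta_rule` kernel, and the sizes are illustrative:
import torch
from fla.layers.gated_deltanet import GatedDeltaNet

# 6 heads * 256 head_dim = 1536 = 0.75 * 2048, matching the parameter-allocation note above.
layer = GatedDeltaNet(hidden_size=2048, head_dim=256, num_heads=6, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)            # (output, None, past_key_values)
assert o.shape == x.shape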
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class GatedLinearAttention(nn.Module):
25
+ r"""
26
+ The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
27
+
28
+ Args:
29
+ mode (str, Optional):
30
+ Which GLA kernel to use.
31
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
32
+ Default: `chunk`.
33
+ hidden_size (int, Optional):
34
+ The hidden size of the input. Default: 1024.
35
+ expand_k (float, Optional):
36
+ The expansion ratio for the key dim. Default: 0.5.
37
+ expand_v (float, Optional):
38
+ The expansion ratio for the value dim. Default: 1.0.
39
+ num_heads (int, Optional):
40
+ The number of heads. Default: 4.
41
+ num_kv_heads (int, Optional):
42
+ The number of key/value heads, used for MQA. Default: None.
43
+ feature_map (str, Optional):
44
+ Feature map function applied to queries/keys. Default: None.
45
+ use_short_conv (bool, Optional):
46
+ Whether to use short convolutions. Default: `False`.
47
+ conv_size (int, Optional):
48
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
49
+ conv_bias (bool, Optional):
50
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
51
+ use_output_gate (bool, Optional):
52
+ Whether to use output gate. Default: `True`.
53
+ gate_fn (str, Optional):
54
+ The activation function for the output gate. Default: `swish`.
55
+ elementwise_affine (bool, Optional):
56
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
57
+ norm_eps (float, Optional):
58
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
59
+ gate_logit_normalizer (int, Optional):
60
+ The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
61
+ gate_low_rank_dim (int, Optional):
62
+ The low rank dim for the gate projection. Default: 16.
63
+ clamp_min (float, Optional):
64
+ The minimum value for the gate logits. Default: None.
65
+ fuse_norm (bool, Optional):
66
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
67
+ layer_idx (int, Optional):
68
+ The index of the layer. Default: None.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ hidden_size: int = 1024,
75
+ expand_k: float = 0.5,
76
+ expand_v: float = 1.0,
77
+ num_heads: int = 4,
78
+ num_kv_heads: Optional[int] = None,
79
+ feature_map: Optional[str] = None,
80
+ use_short_conv: bool = False,
81
+ conv_size: int = 4,
82
+ conv_bias: bool = False,
83
+ use_output_gate: bool = True,
84
+ gate_fn: str = 'swish',
85
+ elementwise_affine: Optional[bool] = True,
86
+ norm_eps: float = 1e-5,
87
+ gate_logit_normalizer: int = 16,
88
+ gate_low_rank_dim: int = 16,
89
+ clamp_min: Optional[float] = None,
90
+ fuse_norm: bool = True,
91
+ layer_idx: int = None,
92
+ ) -> GatedLinearAttention:
93
+ super().__init__()
94
+
95
+ self.mode = mode
96
+ self.hidden_size = hidden_size
97
+ self.expand_k = expand_k
98
+ self.expand_v = expand_v
99
+ self.num_heads = num_heads
100
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
101
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
102
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
103
+
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+ self.use_output_gate = use_output_gate
108
+
109
+ self.key_dim = int(hidden_size * expand_k)
110
+ self.value_dim = int(hidden_size * expand_v)
111
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
112
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
113
+ self.clamp_min = clamp_min
114
+ self.layer_idx = layer_idx
115
+
116
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
117
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
118
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
119
+
120
+ self.head_k_dim = self.key_dim // num_heads
121
+ self.head_v_dim = self.value_dim // num_heads
122
+
123
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
124
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
125
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
126
+ if self.use_output_gate:
127
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ if use_short_conv:
130
+ self.conv_size = conv_size
131
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
132
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
133
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
134
+
135
+ self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
136
+ nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
137
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
138
+
139
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
140
+ self.g_norm_swish_gate = FusedRMSNormGated(
141
+ hidden_size=self.head_v_dim,
142
+ elementwise_affine=elementwise_affine,
143
+ eps=norm_eps
144
+ )
145
+ self.fuse_norm_and_gate = True
146
+ else:
147
+ self.fuse_norm_and_gate = False
148
+ self.g_norm = RMSNorm(
149
+ hidden_size=self.head_v_dim,
150
+ elementwise_affine=elementwise_affine,
151
+ eps=norm_eps
152
+ )
153
+ self.gate_fn = ACT2FN[gate_fn]
154
+
155
+ self.gate_logit_normalizer = gate_logit_normalizer
156
+
157
+ def forward(
158
+ self,
159
+ hidden_states: torch.Tensor,
160
+ attention_mask: Optional[torch.Tensor] = None,
161
+ past_key_values: Optional[Cache] = None,
162
+ use_cache: Optional[bool] = False,
163
+ output_attentions: Optional[bool] = False,
164
+ **kwargs: Unpack[Dict]
165
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
166
+ if attention_mask is not None:
167
+ assert len(attention_mask.shape) == 2, (
168
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
169
+ "for padding purposes (0 indicating padding). "
170
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
171
+ )
172
+
173
+ # launching the triton kernel for just one token will actually be slower
174
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
175
+
176
+ last_state = None
177
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
178
+ last_state = past_key_values[self.layer_idx]
179
+
180
+ cu_seqlens = kwargs.get('cu_seqlens', None)
181
+ if self.use_short_conv:
182
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
183
+ if last_state is not None:
184
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
185
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
186
+ q, conv_state_q = self.q_conv1d(
187
+ x=self.q_proj(hidden_states),
188
+ mask=conv_mask,
189
+ cache=conv_state_q,
190
+ output_final_state=use_cache,
191
+ cu_seqlens=cu_seqlens
192
+ )
193
+ k, conv_state_k = self.k_conv1d(
194
+ x=self.k_proj(hidden_states),
195
+ mask=conv_mask,
196
+ cache=conv_state_k,
197
+ output_final_state=use_cache,
198
+ cu_seqlens=cu_seqlens
199
+ )
200
+ v, conv_state_v = self.v_conv1d(
201
+ x=self.v_proj(hidden_states),
202
+ mask=conv_mask,
203
+ cache=conv_state_v,
204
+ output_final_state=use_cache,
205
+ cu_seqlens=cu_seqlens
206
+ )
207
+ else:
208
+ q = self.q_proj(hidden_states)
209
+ k = self.k_proj(hidden_states)
210
+ v = self.v_proj(hidden_states)
211
+ gk = self.gk_proj(hidden_states)
212
+
213
+ if self.feature_map_fn is not None:
214
+ q, k = map(self.feature_map_fn, (q, k))
215
+ # dealing with left-padding
216
+ if attention_mask is not None:
217
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
218
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
219
+ if self.num_kv_groups > 1:
220
+ k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
221
+ v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
222
+ else:
223
+ k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
224
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
225
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
226
+
227
+ if self.clamp_min is not None:
228
+ gk = torch.clamp_min(gk, self.clamp_min)
229
+
230
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
231
+ if mode == 'fused_recurrent':
232
+ o, recurrent_state = fused_recurrent_gla(
233
+ q=q,
234
+ k=k,
235
+ v=v,
236
+ gk=gk,
237
+ initial_state=recurrent_state,
238
+ output_final_state=use_cache,
239
+ cu_seqlens=cu_seqlens,
240
+ head_first=False
241
+ )
242
+ elif mode == 'fused_chunk':
243
+ o, recurrent_state = fused_chunk_gla(
244
+ q=q,
245
+ k=k,
246
+ v=v,
247
+ g=gk,
248
+ initial_state=recurrent_state,
249
+ output_final_state=use_cache,
250
+ head_first=False
251
+ )
252
+ elif mode == 'chunk':
253
+ o, recurrent_state = chunk_gla(
254
+ q=q,
255
+ k=k,
256
+ v=v,
257
+ g=gk,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False
262
+ )
263
+ else:
264
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
265
+
266
+ if past_key_values is not None:
267
+ past_key_values.update(
268
+ recurrent_state=recurrent_state,
269
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
270
+ layer_idx=self.layer_idx,
271
+ offset=q.shape[1]
272
+ )
273
+
274
+ if self.use_output_gate:
275
+ g = self.g_proj(hidden_states)
276
+ if self.fuse_norm_and_gate:
277
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
278
+ o = self.g_norm_swish_gate(o, g)
279
+ o = rearrange(o, 'b t h d -> b t (h d)')
280
+ else:
281
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
282
+ o = o * self.gate_fn(g)
283
+ else:
284
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
285
+ o = self.o_proj(o)
286
+
287
+ return o, None, past_key_values
288
+
289
+ def state_size(self, **kwargs) -> int:
290
+ state_size = self.key_dim * self.head_v_dim
291
+ for module in self.children():
292
+ if isinstance(module, ShortConvolution):
293
+ state_size += module.state_size
294
+ return state_size
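A minimal forward-pass sketch for `GatedLinearAttention`, assuming a CUDA device for the Triton `chunk_gla` kernel; with `expand_k=0.5` the key dim is half the hidden size while values keep the full width. The sizes are illustrative assumptions:
import torch
from fla.layers.gla import GatedLinearAttention

layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)            # output gate + fused RMSNorm applied before o_proj
assert o.shape == x.shape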
fla/layers/gsa.py ADDED
@@ -0,0 +1,227 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class GatedSlotAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ expand_k: float = 1.,
32
+ expand_v: float = 1.,
33
+ num_heads: int = 4,
34
+ num_kv_heads: Optional[int] = None,
35
+ use_short_conv: bool = False,
36
+ conv_size: int = 4,
37
+ conv_bias: bool = False,
38
+ num_slots: Optional[int] = None,
39
+ elementwise_affine: Optional[bool] = True,
40
+ norm_eps: float = 1e-5,
41
+ gate_logit_normalizer: int = 8,
42
+ feature_map: str = 'swish',
43
+ use_output_gate: bool = False,
44
+ use_norm: bool = True,
45
+ layer_idx: Optional[int] = None,
46
+ scale: Optional[float] = 1.,
47
+ **kwargs
48
+ ) -> GatedSlotAttention:
49
+ super().__init__()
50
+
51
+ self.mode = mode
52
+ self.hidden_size = hidden_size
53
+ self.expand_k = expand_k
54
+ self.expand_v = expand_v
55
+ self.num_heads = num_heads
56
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
57
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
58
+ self.key_dim = int(hidden_size * expand_k)
59
+ self.value_dim = int(hidden_size * expand_v)
60
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
61
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
62
+ self.head_k_dim = self.key_dim // self.num_heads
63
+ self.head_v_dim = self.value_dim // self.num_heads
64
+
65
+ self.use_short_conv = use_short_conv
66
+ self.conv_size = conv_size
67
+ self.conv_bias = conv_bias
68
+
69
+ self.gate_logit_normalizer = gate_logit_normalizer
70
+
71
+ self.use_output_gate = use_output_gate
72
+ self.use_norm = use_norm
73
+ self.scale = scale
74
+
75
+ if num_slots is None:
76
+ num_slots = self.head_k_dim
77
+ self.num_slots = num_slots
78
+
79
+ self.layer_idx = layer_idx
80
+
81
+ if layer_idx is None:
82
+ warnings.warn(
83
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
84
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
85
+ "when creating this class."
86
+ )
87
+
88
+ self.register_module('feature_map', None)
89
+ if feature_map == 'swish':
90
+ self.feature_map = SwishFeatureMap()
91
+ elif feature_map == 'relu':
92
+ self.feature_map = ReLUFeatureMap()
93
+ elif feature_map == 't2r':
94
+ self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
95
+ else:
96
+ raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")
97
+
98
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
99
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
100
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
101
+ self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
102
+
103
+ if use_short_conv:
104
+ self.conv_size = conv_size
105
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
106
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
107
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
108
+
109
+ self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
110
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.Tensor,
115
+ attention_mask: Optional[torch.Tensor] = None,
116
+ past_key_values: Optional[Cache] = None,
117
+ use_cache: Optional[bool] = False,
118
+ output_attentions: Optional[bool] = False,
119
+ **kwargs: Unpack[Dict]
120
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
121
+ if attention_mask is not None:
122
+ assert len(attention_mask.shape) == 2, (
123
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
124
+ "for padding purposes (0 indicating padding). "
125
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
126
+ )
127
+
128
+ # launching the triton kernel for just one token will actually be slower
129
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
130
+
131
+ last_state = None
132
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
133
+ last_state = past_key_values[self.layer_idx]
134
+
135
+ cu_seqlens = kwargs.get('cu_seqlens', None)
136
+ if self.use_short_conv:
137
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
138
+ if last_state is not None:
139
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
140
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
141
+ q, conv_state_q = self.q_conv1d(
142
+ x=self.q_proj(hidden_states),
143
+ mask=conv_mask,
144
+ cache=conv_state_q,
145
+ output_final_state=use_cache,
146
+ cu_seqlens=cu_seqlens
147
+ )
148
+ k, conv_state_k = self.k_conv1d(
149
+ x=self.k_proj(hidden_states),
150
+ mask=conv_mask,
151
+ cache=conv_state_k,
152
+ output_final_state=use_cache,
153
+ cu_seqlens=cu_seqlens
154
+ )
155
+ v, conv_state_v = self.v_conv1d(
156
+ x=self.v_proj(hidden_states),
157
+ mask=conv_mask,
158
+ cache=conv_state_v,
159
+ output_final_state=use_cache,
160
+ cu_seqlens=cu_seqlens
161
+ )
162
+ else:
163
+ q = self.q_proj(hidden_states)
164
+ k = self.k_proj(hidden_states)
165
+ v = self.v_proj(hidden_states)
166
+ f = self.f_proj(hidden_states)
167
+
168
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
169
+ k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
170
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
171
+ f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)
172
+
173
+ if self.feature_map is not None:
174
+ q, k = map(lambda x: self.feature_map(x), (q, k))
175
+ v = F.silu(v)
176
+
177
+ f = F.logsigmoid(f) / self.gate_logit_normalizer
178
+ s = (1 - f.exp()).to(f.dtype)
179
+ # dealing with left-padding
180
+ if attention_mask is not None:
181
+ s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
182
+ v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])
183
+
184
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
185
+ if mode == 'fused_recurrent':
186
+ o, recurrent_state = fused_recurrent_gsa(
187
+ q=q,
188
+ k=k,
189
+ v=v,
190
+ s=s,
191
+ g=f,
192
+ initial_state=recurrent_state,
193
+ output_final_state=use_cache,
194
+ scale=self.scale,
195
+ cu_seqlens=cu_seqlens,
196
+ head_first=False
197
+ )
198
+ elif mode == 'chunk':
199
+ o, recurrent_state = chunk_gsa(
200
+ q=q,
201
+ k=k,
202
+ v=v,
203
+ s=s,
204
+ g=f,
205
+ initial_state=recurrent_state,
206
+ output_final_state=use_cache,
207
+ scale=self.scale,
208
+ cu_seqlens=cu_seqlens,
209
+ head_first=False
210
+ )
211
+ else:
212
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
213
+
214
+ if past_key_values is not None:
215
+ past_key_values.update(
216
+ recurrent_state=recurrent_state,
217
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
218
+ layer_idx=self.layer_idx,
219
+ offset=q.shape[1]
220
+ )
221
+
222
+ o = rearrange(o, 'b t h d -> b t (h d)')
223
+ o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
224
+ return o, None, past_key_values
225
+
226
+ def state_size(self, *args, **kwargs) -> int:
227
+ return 2 * self.num_slots * self.hidden_size
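A minimal forward-pass sketch for `GatedSlotAttention`, assuming a CUDA device for the Triton `chunk_gsa` kernel; with the defaults `num_slots` falls back to `head_k_dim` (1024 / 4 = 256 slots per head). The sizes are illustrative assumptions:
import torch
from fla.layers.gsa import GatedSlotAttention

layer = GatedSlotAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)            # slot gates come from f_proj + logsigmoid inside forward
assert o.shape == x.shape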
fla/layers/rebased.py ADDED
@@ -0,0 +1,133 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
+ class ReBasedLinearAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ l_max: int = 2048,
27
+ feature_dim: int = 16,
28
+ num_key_value_heads: int = 16,
29
+ num_heads: int = 16,
30
+ use_gamma: Optional[bool] = True,
31
+ use_beta: Optional[bool] = True,
32
+ normalize: Optional[bool] = True,
33
+ causal: bool = True,
34
+ eps: float = 1e-5,
35
+ mode: str = "parallel",
36
+ layer_idx: Optional[int] = None,
37
+ **kwargs
38
+ ) -> ReBasedLinearAttention:
39
+ super().__init__()
40
+ self.hidden_size = hidden_size
41
+ self.l_max = l_max
42
+ self.mode = mode
43
+ assert self.mode in ["fused_chunk", "parallel", 'chunk']
44
+
45
+ self.feature_dim = feature_dim
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.num_heads = num_heads
48
+ self.head_dim = self.hidden_size // self.num_key_value_heads
49
+ self.use_gamma = use_gamma
50
+ self.use_beta = use_beta
51
+ self.normalize = normalize
52
+ self.causal = causal
53
+ self.eps = eps
54
+ self.mode = mode
55
+ self.layer_idx = layer_idx
56
+
57
+ self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
58
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
61
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
62
+ self.dropout = nn.Identity()
63
+
64
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
65
+ mode = self.mode
66
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
67
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
68
+ q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
69
+ if mode == "fused_chunk":
70
+ o = fused_chunk_linear_attn(
71
+ q=q,
72
+ k=k,
73
+ v=v,
74
+ normalize=True,
75
+ scale=1,
76
+ head_first=False
77
+ )
78
+ elif mode == 'chunk':
79
+ o = chunk_linear_attn(
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ normalize=True,
84
+ scale=1,
85
+ head_first=False
86
+ )
87
+ elif mode == 'parallel':
88
+ assert q.shape[-1] <= 128
89
+ o = parallel_rebased(
90
+ q=q,
91
+ k=k,
92
+ v=v,
93
+ eps=self.eps,
94
+ use_scale=True,
95
+ use_normalize=True,
96
+ head_first=False
97
+ )
98
+ o = self.o_proj(o)
99
+ o = self.dropout(o)
100
+ return o
101
+
102
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
103
+ def forward_reference(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ filters: torch.Tensor = None,
107
+ *args,
108
+ **kwargs
109
+ ):
110
+ """
111
+ hidden_states (torch.Tensor): input tensor of shape (b, t, d)
113
+ Returns: torch.Tensor of shape (b, t, d)
113
+ """
114
+ b, t, _ = hidden_states.size()
115
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
116
+
117
+ q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
118
+ k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
119
+ v = v.view(b, t, -1, self.head_dim).transpose(1, 2)
120
+
121
+ # Linear attention
122
+ q, k = self.feature_map(q), self.feature_map(k)
123
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
124
+
125
+ # Compute attention
126
+ if self.causal:
127
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
128
+ else:
129
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
130
+ y = rearrange(y, 'b h t d -> b t (h d)')
131
+ y = self.o_proj(y.to(hidden_states.dtype))
132
+ y = self.dropout(y)
133
+ return y.to(hidden_states.dtype)
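A short sketch of the pure-PyTorch `forward_reference` path of `ReBasedLinearAttention`, which avoids the Triton kernels and runs on CPU; with the defaults `head_dim = 1024 // 16 = 64` and the learnable ReBased feature map acts on `feature_dim = 16`. The sizes are illustrative assumptions:
import torch
from fla.layers.rebased import ReBasedLinearAttention

layer = ReBasedLinearAttention(hidden_size=1024, layer_idx=0)
x = torch.randn(2, 32, 1024)
y = layer.forward_reference(x)   # quadratic-time reference; forward() uses parallel_rebased instead
assert y.shape == x.shape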
fla/layers/rwkv7.py ADDED
@@ -0,0 +1,221 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.layers.rwkv6 import LoRA
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.l2norm import l2_norm
16
+ from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV7Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ head_dim: Optional[int] = 64,
29
+ num_heads: Optional[int] = None,
30
+ decay_low_rank_dim: int = 64,
31
+ gate_low_rank_dim: int = 128,
32
+ a_low_rank_dim: int = 64,
33
+ v_low_rank_dim: int = 16,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: int = None,
37
+ fuse_norm: bool = False,
38
+ value_dim: int = None,
39
+ **kwargs
40
+ ) -> RWKV7Attention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
45
+ self.hidden_size = hidden_size
46
+
47
+ self.key_dim = hidden_size
48
+ self.value_dim = value_dim if value_dim is not None else hidden_size
49
+ if head_dim is None and num_heads is None:
50
+ raise ValueError("Either `head_dim` or `num_heads` must be specified.")
51
+ elif head_dim is not None:
52
+ self.head_dim = head_dim
53
+ self.num_heads = int(hidden_size // head_dim)
54
+ elif num_heads is not None:
55
+ self.head_dim = int(hidden_size // num_heads)
56
+ self.num_heads = num_heads
57
+ self.head_v_dim = int(self.value_dim // self.num_heads)
58
+
59
+ self.decay_low_rank_dim = decay_low_rank_dim
60
+ self.gate_low_rank_dim = gate_low_rank_dim
61
+ self.a_low_rank_dim = a_low_rank_dim
62
+ self.v_low_rank_dim = v_low_rank_dim
63
+ self.layer_idx = layer_idx
64
+ self.fuse_norm = fuse_norm
65
+
66
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
67
+
68
+ self.x_x = nn.Parameter(torch.zeros(6, hidden_size))
69
+
70
+ self.k_k = nn.Parameter(torch.zeros(self.key_dim))
71
+ self.k_a = nn.Parameter(torch.zeros(self.key_dim))
72
+ self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim))
73
+
74
+ self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
75
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
76
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
77
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
78
+
79
+ self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh')
80
+ if self.layer_idx != 0:
81
+ self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None)
82
+ self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None)
83
+ self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False)
84
+
85
+ if self.fuse_norm:
86
+ self.g_norm = GroupNorm(
87
+ num_groups=self.num_heads,
88
+ hidden_size=self.value_dim,
89
+ elementwise_affine=elementwise_affine,
90
+ eps=self.head_dim*norm_eps,
91
+ bias=True,
92
+ )
93
+ else:
94
+ self.g_norm = nn.GroupNorm(
95
+ num_groups=self.num_heads,
96
+ num_channels=self.value_dim,
97
+ eps=self.head_dim*norm_eps,
98
+ affine=elementwise_affine
99
+ )
100
+
101
+ self.apply(self._initialize_weights)
102
+
103
+ def _initialize_weights(self, module: nn.Module):
104
+ if getattr(module, "_is_hf_initialized", False):
105
+ return
106
+ if isinstance(module, nn.Linear):
107
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
108
+ if module.bias is not None:
109
+ nn.init.zeros_(module.bias)
110
+ if isinstance(module, nn.Parameter):
111
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
112
+ module._is_hf_initialized = True
113
+
114
+ def forward(
115
+ self,
116
+ hidden_states: torch.Tensor,
117
+ attention_mask: Optional[torch.Tensor] = None,
118
+ past_key_values: Optional[Cache] = None,
119
+ use_cache: Optional[bool] = False,
120
+ output_attentions: Optional[bool] = False,
121
+ v_first: torch.Tensor = None,
122
+ **kwargs
123
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
124
+ if attention_mask is not None:
125
+ assert len(attention_mask.shape) == 2, (
126
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
127
+ "for padding purposes (0 indicating padding). "
128
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
129
+ )
130
+
131
+ batch_size, seq_len, _ = hidden_states.shape
132
+
133
+ if self.training:
134
+ # if training, use chunk mode no matter how short the sequence is
135
+ mode = 'chunk'
136
+ else:
137
+ # launching the triton kernel for just one token will actually be slower
138
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
139
+
140
+ last_state = None
141
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
142
+ last_state = past_key_values[self.layer_idx]
143
+
144
+ if attention_mask is not None:
145
+ hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
146
+ if hidden_states.shape[1] == 1 and last_state is not None:
147
+ shifted = last_state['conv_state'].unsqueeze(1)
148
+ else:
149
+ shifted = self.time_shift(hidden_states)
150
+ if last_state is not None:
151
+ shifted[:, 0] = last_state['conv_state']
152
+
153
+ # [batch_size, seq_len, hidden_size]
154
+ delta = shifted - hidden_states
155
+ xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)
156
+
157
+ r = self.r_proj(xr)
158
+ # -math.exp(-0.5) = -0.6065306597126334
159
+ # I think .to(torch.float) is unnecessary here, since we compute the LoRA in bfloat16;
160
+ # applying sigmoid to a bf16 input does not cause numerical issues
161
+ # FIXME: check if we can remove .to(torch.float)
162
+ w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()
163
+
164
+ k = self.k_proj(xk)
165
+ v = self.v_proj(xv)
166
+
167
+ if self.layer_idx == 0:
168
+ v_first = v
169
+ else:
170
+ v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid())
171
+ a = self.a_lora(xa).sigmoid()
172
+ g = self.g_lora(xg)
173
+
174
+ if self.fuse_norm:
175
+ kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim))
176
+ else:
177
+ kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0)
178
+
179
+ k = k.addcmul(k * (a - 1), self.k_a)
180
+
181
+ # dealing with left-padding
182
+ if attention_mask is not None:
183
+ v = v * attention_mask[:, -v.shape[-2]:, None]
184
+ r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
185
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
186
+
187
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
188
+
189
+ rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
190
+ cu_seqlens = kwargs.get('cu_seqlens', None)
191
+ o, recurrent_state = rwkv7_fn(
192
+ r=r,
193
+ w=w,
194
+ k=k,
195
+ v=v,
196
+ a=-kk,
197
+ b=kk * a,
198
+ scale=1.,
199
+ initial_state=recurrent_state,
200
+ output_final_state=use_cache,
201
+ cu_seqlens=cu_seqlens,
202
+ head_first=False
203
+ )
204
+
205
+ if past_key_values is not None:
206
+ past_key_values.update(
207
+ recurrent_state=recurrent_state,
208
+ conv_state=hidden_states[:, -1],
209
+ layer_idx=self.layer_idx,
210
+ offset=r.shape[1]
211
+ )
212
+
213
+ if self.fuse_norm:
214
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)'))
215
+ else:
216
+ o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1)
217
+
218
+ o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1)
219
+ o = self.o_proj(o * g)
220
+
221
+ return o, None, past_key_values, v_first
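A minimal usage sketch for the layer above (hyper-parameters and shapes are illustrative; the `chunk`/`fused_recurrent` paths are Triton kernels and need a CUDA device):

import torch
from fla.layers.rwkv7 import RWKV7Attention

attn = RWKV7Attention(mode='chunk', hidden_size=1024, head_dim=64, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _, v_first = attn(x)      # layer 0 also emits v_first, which deeper layers reuse
print(o.shape)                  # torch.Size([2, 128, 1024])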
fla/modules/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution
4
+ from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear
5
+ from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
6
+ from fla.modules.fused_kl_div import FusedKLDivLoss
7
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
8
+ from fla.modules.fused_norm_gate import (
9
+ FusedLayerNormGated,
10
+ FusedLayerNormSwishGate,
11
+ FusedLayerNormSwishGateLinear,
12
+ FusedRMSNormGated,
13
+ FusedRMSNormSwishGate,
14
+ FusedRMSNormSwishGateLinear
15
+ )
16
+ from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear
17
+ from fla.modules.mlp import GatedMLP
18
+ from fla.modules.rotary import RotaryEmbedding
19
+
20
+ __all__ = [
21
+ 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
22
+ 'BitLinear', 'FusedBitLinear',
23
+ 'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss',
24
+ 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
25
+ 'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear',
26
+ 'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
27
+ 'GatedMLP',
28
+ 'RotaryEmbedding'
29
+ ]
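These re-exports define the intended import surface; downstream layers pull modules from `fla.modules` directly, e.g. (a small illustrative snippet; the hidden size is arbitrary):

from fla.modules import GatedMLP, RMSNorm, RotaryEmbedding, ShortConvolution

norm = RMSNorm(hidden_size=1024)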
fla/modules/convolution.py ADDED
@@ -0,0 +1,434 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from einops import rearrange
15
+
16
+ from fla.modules.activations import ACT2FN
17
+ from fla.ops.common.utils import prepare_position_ids, prepare_sequence_ids
18
+ from fla.utils import checkpoint, input_guard
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn = None
24
+ causal_conv1d_update = None
25
+
26
+
27
+ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
28
+ seqlen = u.shape[-1]
29
+ fft_size = 2 * seqlen
30
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
31
+ if k_rev is not None:
32
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
33
+ k_f = k_f + k_rev_f.conj()
34
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
35
+
36
+ if len(u.shape) > 3:
37
+ k_f = k_f.unsqueeze(1)
38
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
39
+
40
+ out = y + u
41
+ if gelu:
42
+ out = F.gelu(out)
43
+ if dropout_mask is not None:
44
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
45
+ else:
46
+ return out.to(dtype=u.dtype)
47
+
48
+
49
+ @checkpoint
50
+ def proj_then_conv1d(
51
+ x: torch.Tensor,
52
+ proj_weight: torch.Tensor,
53
+ conv1d_weight: torch.Tensor,
54
+ conv1d_bias: Optional[torch.Tensor] = None,
55
+ cache: Optional[torch.Tensor] = None
56
+ ) -> torch.Tensor:
57
+ # We do matmul and transpose BLH -> HBL at the same time
58
+ x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])
59
+
60
+ if causal_conv1d_fn is None:
61
+ raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
62
+ if cache is None:
63
+ x = causal_conv1d_fn(
64
+ x=x,
65
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
66
+ bias=conv1d_bias,
67
+ activation="silu",
68
+ ).transpose(1, 2)
69
+ else:
70
+ assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
71
+ x = x.squeeze(-1)
72
+ x = causal_conv1d_update(
73
+ x=x,
74
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
75
+ bias=conv1d_bias,
76
+ cache=cache,
77
+ activation="silu",
78
+ )
79
+ return x
80
+
81
+
82
+ @triton.jit
83
+ def causal_conv1d_varlen_states_fwd_kernel(
84
+ x,
85
+ cache,
86
+ offsets,
87
+ D,
88
+ W,
89
+ BD: tl.constexpr,
90
+ BW: tl.constexpr
91
+ ):
92
+ i_d, i_w, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ eos = tl.load(offsets + i_n + 1)
94
+ bos = tl.maximum(tl.load(offsets + i_n), eos - W)
95
+ o_t = eos - (i_w + 1) * BW + tl.arange(0, BW)
96
+ o_d = i_d * BD + tl.arange(0, BD)
97
+ o_w = W - (i_w + 1) * BW + tl.arange(0, BW)
98
+
99
+ b_x = tl.load(x + o_t * D + o_d[:, None], mask=(o_t >= bos) & (o_d[:, None] < D), other=0)
100
+ tl.store(cache + i_n * D*W + o_d[:, None] * W + o_w, b_x, mask=(o_d[:, None] < D) & (o_w >= 0))
101
+
102
+
103
+ @input_guard
104
+ def causal_conv1d_varlen_states_fwd(
105
+ x: torch.Tensor,
106
+ cache: torch.Tensor,
107
+ cu_seqlens: torch.Tensor,
108
+ state_len: int
109
+ ) -> torch.Tensor:
110
+ N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len
111
+ cache = torch.empty(N, D, W, dtype=x.dtype, device=x.device) if cache is None else cache
112
+ BD = min(triton.next_power_of_2(D), 256)
113
+ BW = min(triton.next_power_of_2(state_len), 16)
114
+ grid = (triton.cdiv(D, BD), triton.cdiv(W, BW), N)
115
+ with torch.cuda.device(x.device.index):
116
+ causal_conv1d_varlen_states_fwd_kernel[grid](
117
+ x=x,
118
+ cache=cache,
119
+ offsets=cu_seqlens,
120
+ D=D,
121
+ W=W,
122
+ BW=BW,
123
+ BD=BD
124
+ )
125
+ return cache
126
+
127
+
128
+ class ShortConvolution(nn.Conv1d):
129
+ """
130
+ Simple wrapper around `nn.Conv1d` that accepts dimension last.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ hidden_size: int,
136
+ kernel_size: int,
137
+ bias: bool = False,
138
+ activation: Optional[str] = 'silu',
139
+ use_fast_conv1d: Optional[bool] = True,
140
+ device: Optional[torch.device] = None,
141
+ dtype: Optional[torch.dtype] = None,
142
+ ):
143
+ super().__init__(
144
+ in_channels=hidden_size,
145
+ out_channels=hidden_size,
146
+ kernel_size=kernel_size,
147
+ groups=hidden_size,
148
+ bias=bias,
149
+ padding=kernel_size - 1,
150
+ device=device,
151
+ dtype=dtype,
152
+ )
153
+
154
+ self.hidden_size = hidden_size
155
+ self.activation = None
156
+ if activation is not None:
157
+ assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
158
+ self.activation = activation
159
+
160
+ if causal_conv1d_fn is None:
161
+ if use_fast_conv1d:
162
+ raise RuntimeError(
163
+ "Please either install `causal-conv1d>=1.4.0` to enable fast causal short convolution CUDA kernel "
164
+ "or set `use_fast_conv1d` to False"
165
+ )
166
+ else:
167
+ warnings.warn(
168
+ "The naive PyTorch version is very slow in practice; "
169
+ "please run `pip install causal-conv1d>=1.4.0` to install the fast causal short convolution CUDA kernel",
170
+ category=ImportWarning
171
+ )
172
+ self.use_fast_conv1d = use_fast_conv1d
173
+
174
+ def extra_repr(self):
175
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
176
+ ', stride={stride}')
177
+ if self.padding != (0,) * len(self.padding):
178
+ s += ', padding={padding}'
179
+ if self.dilation != (1,) * len(self.dilation):
180
+ s += ', dilation={dilation}'
181
+ if self.output_padding != (0,) * len(self.output_padding):
182
+ s += ', output_padding={output_padding}'
183
+ if self.groups != 1:
184
+ s += ', groups={groups}'
185
+ if self.bias is None:
186
+ s += ', bias=False'
187
+ if self.padding_mode != 'zeros':
188
+ s += ', padding_mode={padding_mode}'
189
+ if self.activation is not None:
190
+ s += ', activation={activation}'
191
+ if not self.use_fast_conv1d:
192
+ s += ', use_fast_conv1d={use_fast_conv1d}'
193
+ return s.format(**self.__dict__)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ mask: Optional[torch.Tensor] = None,
199
+ cache: Optional[torch.Tensor] = None,
200
+ output_final_state: bool = False,
201
+ cu_seqlens: Optional[torch.LongTensor] = None,
202
+ **kwargs,
203
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """
205
+ Args:
206
+ x (`torch.Tensor`):
207
+ Tensor of shape `[B, T, D]`.
208
+ If `seq_idx` is provided, `B` must be 1.
209
+ mask (`Optional[torch.Tensor]`):
210
+ Attention mask dealing with padded positions.
211
+ cache (`Optional[torch.Tensor]`):
212
+ Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
213
+ If provided, the cache is updated **inplace**.
214
+ output_final_state (Optional[bool]):
215
+ Whether to output the final state of shape `[N, D, W]`. Default: `False`.
216
+ cu_seqlens (Optional[torch.LongTensor]):
217
+ Cumulative sequence lengths for each batch. Used for varlen. Default: `None`.
218
+ Shape: [B+1]
219
+
220
+ Returns:
221
+ Tensor of shape `[B, T, D]`.
222
+ """
223
+
224
+ B, T, D, W = *x.shape, self.kernel_size[0]
225
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
226
+ if mask is not None:
227
+ if cu_seqlens is not None:
228
+ raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time")
229
+ x = x.mul_(mask.unsqueeze(-1))
230
+ if output_final_state and cache is None:
231
+ cache = x.new_zeros(N, D, W)
232
+ # during the decoding phase, we assume the batch is composed of sequences of length 1
233
+ if cache is not None and B * T == N:
234
+ return self.step(x, cache, cu_seqlens)
235
+
236
+ if cache is not None:
237
+ if cu_seqlens is not None:
238
+ cache = causal_conv1d_varlen_states_fwd(x, cache, cu_seqlens, W)
239
+ else:
240
+ cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w'))
241
+
242
+ x = rearrange(x, 'b t d -> b d t')
243
+ if self.use_fast_conv1d:
244
+ # Sequence index for each token. Used for varlen.
245
+ # Suppose a batch consists of two sequences with lengths 3 and 4,
246
+ # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
247
+ # NOTE: No need to provide this arg if `cu_seqlens` is passed.
248
+ # This arg is just for BC, and will be removed in the future.
249
+ # [B, T]
250
+ seq_idx = kwargs.get('seq_idx', None)
251
+ if cu_seqlens is not None and seq_idx is None:
252
+ seq_idx = prepare_sequence_ids(prepare_position_ids(cu_seqlens)).to(torch.int32).unsqueeze(0)
253
+ x = causal_conv1d_fn(
254
+ x=x,
255
+ weight=rearrange(self.weight, "d 1 w -> d w"),
256
+ bias=self.bias,
257
+ activation=self.activation,
258
+ seq_idx=seq_idx,
259
+ )
260
+ else:
261
+ if cu_seqlens is not None:
262
+ raise ValueError("`cu_seqlens` is not supported for the naive PyTorch version")
263
+ x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
264
+ if self.activation is not None:
265
+ x = ACT2FN[self.activation](x)
266
+ return rearrange(x, "b d t -> b t d"), cache
267
+
268
+ def step(
269
+ self,
270
+ x: torch.Tensor,
271
+ cache: torch.Tensor,
272
+ cu_seqlens: Optional[torch.LongTensor] = None
273
+ ):
274
+ shape = x.shape
275
+ x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
276
+ if self.use_fast_conv1d:
277
+ x = causal_conv1d_update(
278
+ x=x,
279
+ conv_state=cache,
280
+ weight=rearrange(self.weight, "d 1 w -> d w"),
281
+ bias=self.bias,
282
+ activation=self.activation,
283
+ )
284
+ else:
285
+ dtype = x.dtype
286
+ # we follow the fast mode that updates the cache in-place
287
+ cache.copy_(cache.roll(shifts=-1, dims=-1))
288
+ cache[:, :, -1] = x
289
+ x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
290
+ if self.bias is not None:
291
+ x = x + self.bias
292
+ if self.activation is not None:
293
+ x = ACT2FN[self.activation](x).to(dtype=dtype)
294
+ return x.view(shape), cache
295
+
296
+ @property
297
+ def state_size(self) -> int:
298
+ return self.hidden_size * self.kernel_size[0]
299
+
300
+
301
+ class LongConvolution(nn.Module):
302
+ """
303
+ LongConvolution applies a convolution operation on the input tensor using a fixed
304
+ filter of length max_len.
305
+ The filter is learned during training and is applied using FFT convolution.
306
+ Args:
307
+ hidden_size (int): The number of expected features in the input and output.
308
+ max_len (int): The maximum sequence length.
309
+ Returns:
310
+ y: [batch_size, seq_len, hidden_size] tensor
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ hidden_size: int,
316
+ max_len: int,
317
+ **kwargs,
318
+ ):
319
+ """
320
+ Initializes the LongConvolution module.
321
+ Args:
322
+ hidden_size (int): The number of expected features in the input and output.
323
+ max_len (int): The maximum sequence length.
324
+ """
325
+ super().__init__()
326
+ self.hidden_size = hidden_size
327
+ self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True)
328
+
329
+ def forward(self, x: torch.Tensor, *args, **kwargs):
330
+ """
331
+ Applies the LongConvolution operation on the input tensor.
332
+ Args:
333
+ x: [batch_size, seq_len, hidden_size] tensor
334
+ Returns:
335
+ y: [batch_size, seq_len, hidden_size] tensor
336
+ """
337
+ x = x.transpose(1, 2)
338
+ y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
339
+ y = y.transpose(1, 2)
340
+ return y.to(dtype=x.dtype)
341
+
342
+
343
+ class PositionalEmbedding(nn.Module):
344
+ def __init__(self, emb_dim: int, seq_len: int, **kwargs):
345
+ """Complex exponential positional embeddings for implicit long convolution filters."""
346
+ super().__init__()
347
+
348
+ self.seq_len = seq_len
349
+ # The time embedding fed to the filters is normalized so that t_f = 1
350
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
351
+
352
+ if emb_dim > 1:
353
+ bands = (emb_dim - 1) // 2
354
+ # To compute the right embeddings we use the "proper" linspace
355
+ t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
356
+ w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
357
+
358
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
359
+ z = torch.exp(-1j * f * w)
360
+ z = torch.cat([t, z.real, z.imag], dim=-1)
361
+ self.z = nn.Parameter(z, requires_grad=False)
362
+
363
+ def forward(self, L):
364
+ return self.z[:, :L]
365
+
366
+
367
+ class ImplicitLongConvolution(nn.Module):
368
+ """
369
+ Long convolution with implicit filter parameterized by an MLP.
370
+
371
+ Args:
372
+ hidden_size (int):
373
+ The number of expected features in the input and output.
374
+ max_len (int):
375
+ The maximum sequence length.
376
+ d_emb (Optional[int]):
377
+ The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine).
378
+ Defaults to 3.
379
+ d_hidden (Optional[int]):
380
+ The number of features in the hidden layer of the MLP. Defaults to 16.
381
+
382
+ Attributes:
383
+ pos_emb (`PositionalEmbedding`): The positional embedding layer.
384
+ mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
385
+
386
+ """
387
+
388
+ def __init__(
389
+ self,
390
+ hidden_size: int,
391
+ max_len: int,
392
+ d_emb: int = 3,
393
+ d_hidden: int = 16,
394
+ **kwargs,
395
+ ):
396
+ """
397
+ Long convolution with implicit filter parameterized by an MLP.
398
+
399
+
400
+ """
401
+ super().__init__()
402
+ self.hidden_size = hidden_size
403
+ self.d_emb = d_emb
404
+
405
+ assert (
406
+ d_emb % 2 != 0 and d_emb >= 3
407
+ ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
408
+ self.pos_emb = PositionalEmbedding(d_emb, max_len)
409
+
410
+ # final linear layer
411
+ self.mlp = nn.Sequential(
412
+ nn.Linear(d_emb, d_hidden),
413
+ torch.nn.ReLU(),
414
+ nn.Linear(d_hidden, hidden_size),
415
+ )
416
+
417
+ def filter(self, seq_len: int, *args, **kwargs):
418
+ k = self.mlp(self.pos_emb(seq_len))
419
+
420
+ return k.transpose(1, 2)
421
+
422
+ def forward(self, x: torch.Tensor, *args, **kwargs):
423
+ """
424
+ Args:
425
+ x: [batch_size, seq_len, hidden_size] tensor
426
+ Returns:
427
+ y: [batch_size, seq_len, hidden_size] tensor
428
+ """
429
+ x = x.transpose(1, 2)
430
+ k = self.filter(x.shape[-1])
431
+ y = fft_conv(x, k, dropout_mask=None, gelu=False)
432
+
433
+ y = y.transpose(1, 2)
434
+ return y.to(dtype=x.dtype)
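A minimal usage sketch for `ShortConvolution` (shapes are illustrative; `use_fast_conv1d=False` selects the pure-PyTorch fallback, so the `causal-conv1d` package is not required):

import torch
from fla.modules import ShortConvolution

conv = ShortConvolution(hidden_size=512, kernel_size=4, use_fast_conv1d=False)
x = torch.randn(2, 64, 512)                  # [B, T, D], dimension last
y, cache = conv(x, output_final_state=True)
print(y.shape, cache.shape)                  # [2, 64, 512] and [2, 512, 4]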
fla/ops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.09 kB).
 
fla/ops/based/fused_chunk.py ADDED
@@ -0,0 +1,374 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+
13
+ @triton.jit(do_not_specialize=['T'])
14
+ def fused_chunk_based_fwd_kernel(
15
+ q,
16
+ k,
17
+ v,
18
+ o,
19
+ z,
20
+ scale, # K ** -0.5
21
+ T,
22
+ B: tl.constexpr,
23
+ H: tl.constexpr,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ ):
30
+ # indices
31
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+
33
+ o_i = tl.arange(0, BT)
34
+
35
+ # [BT, BT]
36
+ m_s = o_i[:, None] >= o_i[None, :]
37
+
38
+ # [BV], zero-order taylor expansion
39
+ b_h_0o = tl.zeros([BV], dtype=tl.float32)
40
+ # [BK, BV], first-order taylor expansion
41
+ b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
42
+ # [BK, BK, BV] second-order taylor expansion
43
+ b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
44
+
45
+ # make block pointers
46
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
49
+ p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+
51
+ p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
52
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
53
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
54
+ k_0o = 0
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BK, BT]
58
+ b_k = tl.load(p_k, boundary_check=(0, 1))
59
+ # [BK*BK, BT]
60
+ b_k_2o = b_k[:, None, :] * b_k[None, :, :]
61
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BT, BK]
65
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
66
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
67
+ b_z = tl.zeros([BT], dtype=tl.float32)
68
+
69
+ # interchunk
70
+ # zero-order
71
+ b_o += b_h_0o
72
+ b_z += k_0o
73
+ # first-order
74
+ b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
75
+ b_z += tl.sum(b_q * k_1o, axis=1)
76
+ # second-order
77
+ b_q_2o = b_q[:, :, None] * b_q[:, None, :]
78
+ b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
79
+ b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
80
+ b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
81
+
82
+ # update running statistics
83
+ k_1o += tl.sum(b_k, axis=1)[None, :]
84
+ k_2o += tl.sum(b_k_2o, axis=1)[None, :]
85
+ k_0o += BT
86
+
87
+ # intrachunk
88
+ # [BT, BT]
89
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
90
+ b_s = 1 + b_s + 0.5 * b_s * b_s
91
+ b_s = tl.where(m_s, b_s, 0)
92
+ b_z += tl.sum(b_s, axis=1)
93
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
94
+ # [BT, BV]
95
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
96
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
97
+
98
+ # update hidden state
99
+ # [BK, BV]
100
+ b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
101
+ b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
102
+ b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
103
+
104
+ p_q = tl.advance(p_q, (BT, 0))
105
+ p_k = tl.advance(p_k, (0, BT))
106
+ p_v = tl.advance(p_v, (BT, 0))
107
+ p_o = tl.advance(p_o, (BT, 0))
108
+ p_z += BT
109
+
110
+
111
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
112
+ @triton.jit
113
+ def fused_chunk_based_bwd_kernel(
114
+ # NV: number of split in the V dimension. NK: number of split in the K dimension
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ dk,
122
+ dv,
123
+ scale, # K ** -0.5
124
+ T,
125
+ B: tl.constexpr,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ ):
133
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
134
+
135
+ o_i = tl.arange(0, BT)
136
+ m_s = o_i[:, None] >= o_i[None, :]
137
+
138
+ # [BV], zero-order taylor expansion
139
+ # b_h_0o = tl.zeros([BV], dtype=tl.float32)
140
+ # [BK, BV], first-order taylor expansion
141
+ b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
142
+ # [BK, BK, BV] second-order taylor expansion
143
+ b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
144
+
145
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
146
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
147
+
148
+ for i in range(0, tl.cdiv(T, BT)):
149
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
150
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
151
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
152
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
154
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
155
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
156
+
157
+ # load tensors
158
+ # [BT, BK]
159
+ b_q = tl.load(p_q, boundary_check=(0, 1))
160
+ b_q = (b_q * scale).to(b_q.dtype)
161
+ b_k = tl.load(p_k, boundary_check=(0, 1))
162
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
163
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
164
+ # [BV, BT]
165
+ b_v = tl.load(p_v, boundary_check=(0, 1))
166
+
167
+ # inter-chunk
168
+ b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
169
+ if i_v == 0:
170
+ b_dq += b_dz[:, None] * k_1o
171
+ b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
172
+ if i_v == 0:
173
+ b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
174
+ b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
175
+ b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
176
+ b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
177
+ b_dq *= scale
178
+
179
+ # intra-chunk
180
+ # [BT, BT]
181
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
182
+ if i_v == 0:
183
+ b_ds += b_dz[:, None]
184
+ b_ds = tl.where(m_s, b_ds, 0) * scale
185
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
186
+ b_s = tl.where(m_s, b_s, 0)
187
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
188
+
189
+ # store
190
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
191
+
192
+ # update hidden state
193
+ # [BT, BK*BK]
194
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
195
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
196
+ # [BV, BK*BK]
197
+ b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
198
+ # [BV, BK]
199
+ b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
200
+
201
+ if i_v == 0:
202
+ # update running statistics
203
+ k_1o += tl.sum(b_k, axis=0)[None, :]
204
+ k_2o += tl.sum(b_k_2o, axis=0)[None, :]
205
+
206
+ tl.debug_barrier()
207
+ b_h_1o = None
208
+ b_h_2o = None
209
+
210
+ # [BK, BV], first-order taylor expansion
211
+ b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
212
+ # [BK, BK, BV] second-order taylor expansion
213
+ b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
214
+ b_dh_0o = tl.zeros([BV], dtype=tl.float32)
215
+ m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
216
+
217
+ dq_1o = tl.zeros([1, BK], dtype=tl.float32)
218
+ dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
219
+
220
+ for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
221
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1))
222
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0))
223
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
224
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
225
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0))
226
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0))
227
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
228
+
229
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
230
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
231
+
232
+ b_q = tl.load(p_q, boundary_check=(0, 1))
233
+ b_k = tl.load(p_k, boundary_check=(0, 1))
234
+ b_v = tl.load(p_v, boundary_check=(0, 1))
235
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
236
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
237
+ b_q = (b_q * scale).to(b_k.dtype)
238
+
239
+ # intra chunk
240
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
241
+ if i_v == 0:
242
+ b_ds += b_dz[None, :]
243
+ b_ds = tl.where(m_s, b_ds, 0)
244
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
245
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
246
+ b_s = tl.where(m_s, b_s, 0)
247
+ b_s2 = tl.where(m_s, b_s2, 0)
248
+ b_ds *= (1+b_s)
249
+
250
+ b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
251
+ b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
252
+
253
+ # inter chunk
254
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
255
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
256
+
257
+ b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
258
+ b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
259
+ b_dv += b_dh_0o
260
+
261
+ b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
262
+
263
+ if i_v == 0:
264
+ b_dk += dq_1o
265
+
266
+ b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
267
+ if i_v == 0:
268
+ b_dk_2o += dq_2o
269
+ b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
270
+ b_k_fp32 = tl.trans(b_k.to(tl.float32))
271
+ b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
272
+ b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
273
+ b_dk += tl.trans(b_dk2)
274
+
275
+ # hidden state update
276
+ b_dh_0o += tl.sum(b_do, axis=0)
277
+ b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
278
+ b_q_2o = b_q[None, :, :] * b_q[:, None, :]
279
+ b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
280
+ b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
281
+
282
+ if i_v == 0:
283
+ dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
284
+ dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
285
+
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
+ class FusedChunkBasedFunction(torch.autograd.Function):
291
+
292
+ @staticmethod
293
+ @input_guard
294
+ @autocast_custom_fwd
295
+ def forward(ctx, q, k, v, scale=1):
296
+ B, H, T, K, V = *k.shape, v.shape[-1]
297
+
298
+ scale = scale
299
+ BT = 16
300
+ BK, BV = min(K, 16), min(V, 32)
301
+ BK, BV = max(BK, 16), max(BV, 16)
302
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
303
+
304
+ num_warps = 4
305
+
306
+ # the norm of o might explode, so we need to use float32 here
307
+ o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
308
+ z = q.new_empty(NK, B, H, T, dtype=torch.float32)
309
+
310
+ grid = (NV, NK, B * H)
311
+ fused_chunk_based_fwd_kernel[grid](
312
+ q, k, v, o, z,
313
+ scale,
314
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
315
+ num_warps=num_warps,
316
+ )
317
+ o = o.sum(0)
318
+ z = z.sum(0)
319
+ ctx.save_for_backward(q, k, v)
320
+ ctx.scale = scale
321
+ return o.to(q.dtype), z.to(z.dtype)
322
+
323
+ @staticmethod
324
+ @input_guard
325
+ @autocast_custom_bwd
326
+ def backward(ctx, do, dz):
327
+ q, k, v = ctx.saved_tensors
328
+ B, H, T, K, V = *k.shape, v.shape[-1]
329
+ scale = ctx.scale
330
+
331
+ BT = 16
332
+ BK, BV = min(K, 16), min(V, 32)
333
+ BK, BV = max(BK, 16), max(BV, 16)
334
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
335
+ num_stages = 1
336
+ num_warps = 4
337
+
338
+ dq = q.new_empty(NV, B, H, T, K)
339
+ dk = q.new_empty(NV, B, H, T, K)
340
+ dv = q.new_empty(NK, B, H, T, V)
341
+ grid = (NV, NK, B * H)
342
+
343
+ fused_chunk_based_bwd_kernel[grid](
344
+ q, k, v, do, dz, dq, dk, dv,
345
+ scale,
346
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
347
+ num_warps=num_warps,
348
+ num_stages=num_stages
349
+ )
350
+ dq = dq.sum(0)
351
+ dk = dk.sum(0)
352
+ dv = dv.sum(0)
353
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
354
+
355
+
356
+ def fused_chunk_based(
357
+ q: torch.Tensor,
358
+ k: torch.Tensor,
359
+ v: torch.Tensor,
360
+ scale: Optional[float] = None,
361
+ use_norm: bool = True,
362
+ head_first: bool = True
363
+ ):
364
+ assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
365
+ if scale is None:
366
+ scale = q.shape[-1] ** -0.5
367
+ if not head_first:
368
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
369
+ o, z = FusedChunkBasedFunction.apply(q, k, v, scale)
370
+ if use_norm:
371
+ o = o / (z[..., None] + 1e-6)
372
+ if not head_first:
373
+ o = o.transpose(1, 2)
374
+ return o.to(q.dtype)
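For orientation, the fused kernels above score tokens with the 2nd-order Taylor expansion of exp(q.k), i.e. 1 + s + 0.5*s^2, applied causally and optionally normalized by the running sum z. An equivalent quadratic-memory reference (illustrative only, not the repo's API):

import torch

def naive_based(q, k, v, use_norm: bool = True, eps: float = 1e-6):
    # q, k: [b, h, t, k_dim] with k_dim <= 16 on the fused path; v: [b, h, t, v_dim]
    q = q * q.shape[-1] ** -0.5
    s = torch.einsum('bhik,bhjk->bhij', q, k)    # raw scores
    s = 1 + s + 0.5 * s * s                      # 2nd-order Taylor expansion of exp
    s = torch.tril(s)                            # causal mask
    o = torch.einsum('bhij,bhjd->bhid', s, v)
    if use_norm:
        o = o / (s.sum(-1, keepdim=True) + eps)  # matches the z normalizer above
    return o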
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Fully parallelized state passing.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from fla.ops.utils.op import exp
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in [32, 64, 128]
26
+ for BV in [32, 64, 128]
27
+ for num_warps in [2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h_parallel(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ indices,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_G: tl.constexpr,
52
+ USE_GK: tl.constexpr,
53
+ USE_GV: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr
58
+ ):
59
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+
61
+ NV = tl.cdiv(V, BV)
62
+ # i_b: batch index
63
+ # i_h: head index
64
+ # i_n: sequence index
65
+ # i_t: chunk index within current sequence
66
+ # i_tg: (global) chunk index across all sequences
67
+ i_k, i_v = i_kv // NV, i_kv % NV
68
+ i_b, i_h = i_bh // H, i_bh % H
69
+ if USE_OFFSETS:
70
+ i_tg = i_t
71
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
72
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
73
+ T = eos - bos
74
+ NT = tl.cdiv(T, BT)
75
+ else:
76
+ bos, eos = i_b * T, i_b * T + T
77
+ NT = tl.cdiv(T, BT)
78
+ i_n, i_tg = i_b, i_b * NT + i_t
79
+ i_nh = i_n * H + i_h
80
+
81
+ if HEAD_FIRST:
82
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
85
+ else:
86
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
89
+
90
+ if i_t == 0:
91
+ if USE_INITIAL_STATE:
92
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
94
+ else:
95
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ # [BK, BT]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ # [BT, BV]
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+
103
+ last_idx = min(i_t * BT + BT, T) - 1
104
+ # scalar decay
105
+ if USE_G:
106
+ if HEAD_FIRST:
107
+ b_g_last = tl.load(g + i_bh * T + last_idx)
108
+ p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
109
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
110
+ else:
111
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
112
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+
128
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
129
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
130
+
131
+ # vector decay, h = h @ Diag(gv)
132
+ if USE_GV:
133
+ if HEAD_FIRST:
134
+ p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
135
+ p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
136
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
137
+ else:
138
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
139
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
140
+
141
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
142
+
143
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
144
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
145
+
146
+ b_h = tl.dot(b_k, b_v)
147
+ if i_t < NT - 1:
148
+ if HEAD_FIRST:
149
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ else:
151
+ p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
153
+ elif STORE_FINAL_STATE:
154
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
155
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
156
+
157
+
158
+ @triton.heuristics({
159
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
160
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
161
+ })
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
165
+ for BK in [32, 64, 128]
166
+ for BV in [32, 64, 128]
167
+ for num_warps in [2, 4, 8, 16]
168
+ for num_stages in [2, 3]
169
+ ],
170
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
171
+ )
172
+ @triton.jit(do_not_specialize=['T'])
173
+ def chunk_fwd_kernel_h_reduction(
174
+ h,
175
+ g,
176
+ gk,
177
+ gv,
178
+ kvt,
179
+ ht,
180
+ offsets,
181
+ chunk_offsets,
182
+ T,
183
+ H: tl.constexpr,
184
+ K: tl.constexpr,
185
+ V: tl.constexpr,
186
+ BT: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_G: tl.constexpr,
190
+ USE_GK: tl.constexpr,
191
+ USE_GV: tl.constexpr,
192
+ STORE_FINAL_STATE: tl.constexpr,
193
+ USE_OFFSETS: tl.constexpr,
194
+ HEAD_FIRST: tl.constexpr
195
+ ):
196
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
197
+ i_n, i_h = i_nh // H, i_nh % H
198
+ if USE_OFFSETS:
199
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
200
+ T = eos - bos
201
+ NT = tl.cdiv(T, BT)
202
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
203
+ else:
204
+ bos, eos = i_n * T, i_n * T + T
205
+ NT = tl.cdiv(T, BT)
206
+ boh = i_n * NT
207
+
208
+ # [BK, BV]
209
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
210
+ for i_t in range(NT):
211
+ if HEAD_FIRST:
212
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
213
+ else:
214
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
215
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
216
+ if i_t > 0:
217
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
218
+
219
+ last_idx = min(i_t * BT + BT, T) - 1
220
+ # scalar decay
221
+ if USE_G:
222
+ if HEAD_FIRST:
223
+ b_g_last = tl.load(g + i_nh * T + last_idx)
224
+ else:
225
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
226
+ b_h *= exp(b_g_last)
227
+
228
+ # vector decay, h = Diag(gk) @ h
229
+ if USE_GK:
230
+ if HEAD_FIRST:
231
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
232
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
233
+ else:
234
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
235
+
236
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
237
+ b_h *= exp(b_gk_last)[:, None]
238
+
239
+ # vector decay, h = h @ Diag(gv)
240
+ if USE_GV:
241
+ if HEAD_FIRST:
242
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
243
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
244
+ else:
245
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
246
+
247
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
248
+ b_h *= exp(b_gv_last)[None, :]
249
+
250
+ if STORE_FINAL_STATE:
251
+ p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
252
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
253
+ b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
254
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
255
+
256
+
257
+ @triton.heuristics({
258
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
259
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
260
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
261
+ })
262
+ @triton.autotune(
263
+ configs=[
264
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
265
+ for BK in [32, 64, 128]
266
+ for BV in [32, 64, 128]
267
+ for num_warps in [2, 4, 8]
268
+ for num_stages in [2, 3, 4]
269
+ ],
270
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
271
+ )
272
+ @triton.jit(do_not_specialize=['T'])
273
+ def chunk_bwd_kernel_dh_parallel(
274
+ q,
275
+ g,
276
+ gk,
277
+ gv,
278
+ do,
279
+ dh,
280
+ dht,
281
+ dh0,
282
+ offsets,
283
+ indices,
284
+ scale,
285
+ T,
286
+ HQ: tl.constexpr,
287
+ H: tl.constexpr,
288
+ K: tl.constexpr,
289
+ V: tl.constexpr,
290
+ BT: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr,
293
+ NG: tl.constexpr,
294
+ USE_G: tl.constexpr,
295
+ USE_GK: tl.constexpr,
296
+ USE_GV: tl.constexpr,
297
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
298
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
299
+ USE_OFFSETS: tl.constexpr,
300
+ HEAD_FIRST: tl.constexpr
301
+ ):
302
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
303
+
304
+ NV = tl.cdiv(V, BV)
305
+ i_k, i_v = i_kv // NV, i_kv % NV
306
+ i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
307
+ i_h = i_hq // NG
308
+ if USE_OFFSETS:
309
+ i_tg = i_t
310
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
311
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
312
+ T = eos - bos
313
+ NT = tl.cdiv(T, BT)
314
+ else:
315
+ bos, eos = i_b * T, i_b * T + T
316
+ NT = tl.cdiv(T, BT)
317
+ i_n, i_tg = i_b, i_b * NT + i_t
318
+ i_nh = i_n * HQ + i_hq
319
+
320
+ if HEAD_FIRST:
321
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
322
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ else:
325
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
326
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
327
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
328
+
329
+ if i_t == NT - 1:
330
+ if USE_FINAL_STATE_GRADIENT:
331
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
332
+ b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
333
+ else:
334
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
335
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
336
+
337
+ # [BK, BT]
338
+ b_q = tl.load(p_q, boundary_check=(0, 1))
339
+ b_q = (b_q * scale).to(b_q.dtype)
340
+ # [BT, BV]
341
+ b_do = tl.load(p_do, boundary_check=(0, 1))
342
+
343
+ if USE_G:
344
+ if HEAD_FIRST:
345
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
346
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
347
+ else:
348
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
349
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
350
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ else:
356
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
357
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
358
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
359
+
360
+ if USE_GV:
361
+ if HEAD_FIRST:
362
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
363
+ else:
364
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
365
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
366
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
367
+
368
+ b_dh = tl.dot(b_q, b_do)
369
+ if i_t > 0:
370
+ if HEAD_FIRST:
371
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
372
+ else:
373
+ p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
374
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
375
+ elif STORE_INITIAL_STATE_GRADIENT:
376
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
377
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
378
+
379
+
380
+ @triton.heuristics({
381
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
382
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
383
+ })
384
+ @triton.autotune(
385
+ configs=[
386
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
387
+ for BK in [32, 64, 128]
388
+ for BV in [32, 64, 128]
389
+ for num_warps in [2, 4, 8, 16]
390
+ for num_stages in [2, 3]
391
+ ],
392
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
393
+ )
394
+ @triton.jit(do_not_specialize=['T'])
395
+ def chunk_bwd_kernel_dh_reduction(
396
+ g,
397
+ gk,
398
+ gv,
399
+ dh,
400
+ doq0,
401
+ dh0,
402
+ offsets,
403
+ chunk_offsets,
404
+ T,
405
+ HQ: tl.constexpr,
406
+ H: tl.constexpr,
407
+ K: tl.constexpr,
408
+ V: tl.constexpr,
409
+ BT: tl.constexpr,
410
+ BK: tl.constexpr,
411
+ BV: tl.constexpr,
412
+ NG: tl.constexpr,
413
+ USE_G: tl.constexpr,
414
+ USE_GK: tl.constexpr,
415
+ USE_GV: tl.constexpr,
416
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
417
+ USE_OFFSETS: tl.constexpr,
418
+ HEAD_FIRST: tl.constexpr
419
+ ):
420
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
421
+ i_bg = i_nh // NG
422
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
423
+ i_h = i_hq // NG
424
+ if USE_OFFSETS:
425
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
426
+ T = eos - bos
427
+ NT = tl.cdiv(T, BT)
428
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
429
+ else:
430
+ bos, eos = i_n * T, i_n * T + T
431
+ NT = tl.cdiv(T, BT)
432
+ boh = i_n * NT
433
+
434
+ # [BK, BV]
435
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
436
+ for i_t in range(NT - 1, -1, -1):
437
+ if HEAD_FIRST:
438
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
439
+ else:
440
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
441
+ b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
442
+ if i_t < NT - 1:
443
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
444
+
445
+ last_idx = min(i_t * BT + BT, T) - 1
446
+ if USE_G:
447
+ if HEAD_FIRST:
448
+ b_g_last = tl.load(g + i_bg * T + last_idx)
449
+ else:
450
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
451
+ b_dh *= exp(b_g_last)
452
+
453
+ if USE_GK:
454
+ if HEAD_FIRST:
455
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
456
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
457
+ else:
458
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
459
+
460
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
461
+ b_dh *= exp(b_gk_last)[:, None]
462
+
463
+ if USE_GV:
464
+ if HEAD_FIRST:
465
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
466
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
467
+ else:
468
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
469
+
470
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
471
+ b_dh *= exp(b_gv_last)[None, :]
472
+
473
+ if STORE_INITIAL_STATE_GRADIENT:
474
+ p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
475
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
476
+ b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
477
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
478
+
479
+
480
+ def chunk_fwd_h(
481
+ k: torch.Tensor,
482
+ v: torch.Tensor,
483
+ g: torch.Tensor,
484
+ gk: torch.Tensor,
485
+ gv: torch.Tensor,
486
+ h0: torch.Tensor,
487
+ output_final_state: bool,
488
+ states_in_fp32: bool = False,
489
+ offsets: Optional[torch.Tensor] = None,
490
+ indices: Optional[torch.Tensor] = None,
491
+ head_first: bool = True,
492
+ chunk_size: int = 64
493
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
494
+ if head_first:
495
+ B, H, T, K, V = *k.shape, v.shape[-1]
496
+ else:
497
+ B, T, H, K, V = *k.shape, v.shape[-1]
498
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
499
+ # N: the actual number of sequences in the batch with either equal or variable lengths
500
+ if offsets is None:
501
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
502
+ else:
503
+ if indices is None:
504
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
505
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
506
+ N, NT = len(offsets) - 1, len(indices)
507
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
508
+
509
+ h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
510
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
511
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
512
+ chunk_fwd_kernel_h_parallel[grid](
513
+ k=k,
514
+ v=v,
515
+ h=h,
516
+ g=g,
517
+ gk=gk,
518
+ gv=gv,
519
+ h0=h0,
520
+ ht=ht,
521
+ offsets=offsets,
522
+ indices=indices,
523
+ T=T,
524
+ H=H,
525
+ K=K,
526
+ V=V,
527
+ BT=BT,
528
+ USE_G=g is not None,
529
+ USE_GK=gk is not None,
530
+ USE_GV=gv is not None,
531
+ HEAD_FIRST=head_first
532
+ )
533
+ kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
534
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
535
+ chunk_fwd_kernel_h_reduction[grid](
536
+ h=h,
537
+ g=g,
538
+ gk=gk,
539
+ gv=gv,
540
+ kvt=kvt,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ chunk_offsets=chunk_offsets,
544
+ T=T,
545
+ H=H,
546
+ K=K,
547
+ V=V,
548
+ BT=BT,
549
+ USE_G=g is not None,
550
+ USE_GK=gk is not None,
551
+ USE_GV=gv is not None,
552
+ HEAD_FIRST=head_first
553
+ )
554
+ h = h.to(k.dtype) if not states_in_fp32 else h
555
+ return h, ht
556
+
557
+
558
+ def chunk_bwd_dh(
559
+ q: torch.Tensor,
560
+ k: torch.Tensor,
561
+ v: torch.Tensor,
562
+ g: torch.Tensor,
563
+ gk: torch.Tensor,
564
+ gv: torch.Tensor,
565
+ do: torch.Tensor,
566
+ h0: torch.Tensor,
567
+ dht: torch.Tensor,
568
+ scale: float,
569
+ states_in_fp32: bool = False,
570
+ offsets: Optional[torch.Tensor] = None,
571
+ indices: Optional[torch.Tensor] = None,
572
+ head_first: bool = True,
573
+ chunk_size: int = 64
574
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
575
+ if head_first:
576
+ B, H, T, K, V = *k.shape, v.shape[-1]
577
+ HQ = q.shape[1]
578
+ else:
579
+ B, T, H, K, V = *k.shape, v.shape[-1]
580
+ HQ = q.shape[2]
581
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
582
+ # N: the actual number of sequences in the batch with either equal or variable lengths
583
+ # NG: number of groups in GQA
584
+ if offsets is None:
585
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
586
+ else:
587
+ if indices is None:
588
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
589
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
590
+ N, NT = len(offsets) - 1, len(indices)
591
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
592
+ NG = HQ // H
593
+
594
+ if head_first:
595
+ dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
596
+ else:
597
+ dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
598
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
599
+
600
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
601
+ chunk_bwd_kernel_dh_parallel[grid](
602
+ q=q,
603
+ g=g,
604
+ gk=gk,
605
+ gv=gv,
606
+ do=do,
607
+ dh=dh,
608
+ dht=dht,
609
+ dh0=dh0,
610
+ offsets=offsets,
611
+ indices=indices,
612
+ scale=scale,
613
+ T=T,
614
+ HQ=HQ,
615
+ H=H,
616
+ K=K,
617
+ V=V,
618
+ BT=BT,
619
+ NG=NG,
620
+ USE_G=g is not None,
621
+ USE_GK=gk is not None,
622
+ USE_GV=gv is not None,
623
+ HEAD_FIRST=head_first
624
+ )
625
+
626
+ doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
627
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
628
+ chunk_bwd_kernel_dh_reduction[grid](
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ dh=dh,
633
+ doq0=doq0,
634
+ dh0=dh0,
635
+ offsets=offsets,
636
+ chunk_offsets=chunk_offsets,
637
+ T=T,
638
+ HQ=HQ,
639
+ H=H,
640
+ K=K,
641
+ V=V,
642
+ BT=BT,
643
+ NG=NG,
644
+ USE_G=g is not None,
645
+ USE_GK=gk is not None,
646
+ USE_GV=gv is not None,
647
+ HEAD_FIRST=head_first
648
+ )
649
+ dh = dh.to(q.dtype) if not states_in_fp32 else dh
650
+ return dh, dh0
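
The two wrappers above implement a parallel-then-reduce pipeline: `chunk_fwd_h` and `chunk_bwd_dh` first launch a per-chunk kernel over a `(K/BK * V/BV, NT, B*H)` grid and then run a reduction kernel that carries the state (or state gradient) across chunks. Below is a minimal usage sketch, assuming a CUDA device with Triton available and the module path `fla.ops.common.chunk_h_parallel` from this upload; the shapes are illustrative only.

```python
# Hedged sketch, not part of the uploaded code.
import torch
from fla.ops.common.chunk_h_parallel import chunk_fwd_h

B, H, T, K, V = 2, 4, 256, 64, 64
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# g/gk/gv are optional (log-)decay terms; passing None disables the corresponding gating path.
h, ht = chunk_fwd_h(
    k=k, v=v, g=None, gk=None, gv=None, h0=None,
    output_final_state=True,  # also return the final [B, H, K, V] state
    head_first=True,
    chunk_size=64,
)
print(h.shape, ht.shape)  # per-chunk states and the final state
```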
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # not supported yet.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
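
`fused_recurrent` above is the generic sequential path: one program per `(value block, key block, sequence*head)` triple scans the timesteps one by one, optionally applying scalar (`g`), per-key (`gk`) or per-value (`gv`) log-decay gates. A short usage sketch under the same assumptions as before (CUDA device, illustrative shapes):

```python
# Hedged sketch, not part of the uploaded code.
import torch
import torch.nn.functional as F
from fla.ops.common.fused_recurrent import fused_recurrent

B, H, T, K, V = 2, 4, 512, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
gk = F.logsigmoid(torch.randn(B, H, T, K, device='cuda'))  # per-key decay, enables the USE_GK path
o, ht = fused_recurrent(q, k, v, gk=gk, output_final_state=True, head_first=True)
```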
fla/ops/delta_rule/fused_chunk.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ def fused_chunk_delta_rule(
4
+ **kwargs
5
+ ):
6
+ raise NotImplementedError("fused_chunk_delta_rule is deprecated. Please use chunk_delta_rule instead.")
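
Since the stub above only raises, existing callers need to migrate. A hedged migration sketch follows, assuming `chunk_delta_rule` is exported from `fla.ops.delta_rule` with a `(q, k, v, beta, ...)` calling convention; neither the import path nor the keyword names are verified against this upload.

```python
# Hedged migration sketch; import path and keyword names are assumptions.
import torch
import torch.nn.functional as F
from fla.ops.delta_rule import chunk_delta_rule

B, T, H, K, V = 1, 128, 2, 64, 64
q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
k = F.normalize(torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16), p=2, dim=-1)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16).sigmoid()
o, ht = chunk_delta_rule(q, k, v, beta, output_final_state=True, head_first=False)
```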
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .chunk import chunk_gated_delta_rule
2
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
3
+
4
+ __all__ = [
5
+ "chunk_gated_delta_rule",
6
+ "fused_recurrent_gated_delta_rule"
7
+ ]
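
With these exports, downstream code can import both entry points directly from the package:

```python
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
```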
fla/ops/gated_delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,321 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.jit(do_not_specialize=['T'])
21
+ def fused_recurrent_gated_delta_rule_fwd_kernel(
22
+ q,
23
+ k,
24
+ v,
25
+ g,
26
+ beta,
27
+ o,
28
+ h0,
29
+ ht,
30
+ offsets,
31
+ scale,
32
+ T,
33
+ B: tl.constexpr,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
40
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
41
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is a headwise vector or a scalar
42
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
43
+ USE_OFFSETS: tl.constexpr
44
+ ):
45
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
46
+ i_n, i_h = i_nh // H, i_nh % H
47
+ if USE_OFFSETS:
48
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
49
+ all = T
50
+ T = eos - bos
51
+ else:
52
+ bos, eos = i_n * T, i_n * T + T
53
+ all = B * T
54
+ o_k = i_k * BK + tl.arange(0, BK)
55
+ o_v = i_v * BV + tl.arange(0, BV)
56
+
57
+ p_q = q + (bos * H + i_h) * K + o_k
58
+ p_k = k + (bos * H + i_h) * K + o_k
59
+ p_v = v + (bos * H + i_h) * V + o_v
60
+ if IS_BETA_HEADWISE:
61
+ p_beta = beta + (bos * H + i_h) * V + o_v
62
+ else:
63
+ p_beta = beta + bos * H + i_h
64
+ p_g = g + bos * H + i_h
65
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v
66
+
67
+ mask_k = o_k < K
68
+ mask_v = o_v < V
69
+ mask_h = mask_k[:, None] & mask_v[None, :]
70
+
71
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
72
+ if USE_INITIAL_STATE:
73
+ p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
74
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
75
+
76
+ for _ in range(0, T):
77
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
78
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
79
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
80
+ b_g = tl.load(p_g).to(tl.float32)
81
+
82
+ if USE_QK_L2NORM_IN_KERNEL:
83
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
84
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
85
+ b_q = b_q * scale
86
+ # [BK, BV]
87
+ b_h *= exp(b_g)
88
+ # [BV]
89
+ b_v -= tl.sum(b_h * b_k[:, None], 0)
90
+ if IS_BETA_HEADWISE:
91
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
92
+ else:
93
+ b_beta = tl.load(p_beta).to(tl.float32)
94
+ b_v *= b_beta
95
+ # [BK, BV]
96
+ b_h += b_k[:, None] * b_v[None, :]
97
+ # [BV]
98
+ b_o = tl.sum(b_h * b_q[:, None], 0)
99
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
100
+
101
+ p_q += H*K
102
+ p_k += H*K
103
+ p_o += H*V
104
+ p_v += H*V
105
+ p_g += H
106
+ p_beta += H * (V if IS_BETA_HEADWISE else 1)
107
+
108
+ if STORE_FINAL_STATE:
109
+ p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
110
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
111
+
112
+
113
+ def fused_recurrent_gated_delta_rule_fwd(
114
+ q: torch.Tensor,
115
+ k: torch.Tensor,
116
+ v: torch.Tensor,
117
+ g: torch.Tensor,
118
+ beta: torch.Tensor,
119
+ scale: float,
120
+ initial_state: torch.Tensor,
121
+ output_final_state: bool,
122
+ use_qk_l2norm_in_kernel: bool = False,
123
+ offsets: Optional[torch.LongTensor] = None,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ B, T, H, K, V = *k.shape, v.shape[-1]
126
+ N = B if offsets is None else len(offsets) - 1
127
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
128
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
129
+ assert NK == 1, "NK > 1 is not supported yet"
130
+ num_stages = 3
131
+ num_warps = 1
132
+
133
+ o = q.new_empty(NK, *v.shape)
134
+ if output_final_state:
135
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
136
+ else:
137
+ final_state = None
138
+
139
+ grid = (NK, NV, N * H)
140
+ fused_recurrent_gated_delta_rule_fwd_kernel[grid](
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ g=g,
145
+ beta=beta,
146
+ o=o,
147
+ h0=initial_state,
148
+ ht=final_state,
149
+ offsets=offsets,
150
+ scale=scale,
151
+ T=T,
152
+ B=B,
153
+ H=H,
154
+ K=K,
155
+ V=V,
156
+ BK=BK,
157
+ BV=BV,
158
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
159
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
160
+ num_warps=num_warps,
161
+ num_stages=num_stages,
162
+ )
163
+ o = o.squeeze(0)
164
+ return o, final_state
165
+
166
+
167
+ class FusedRecurrentFunction(torch.autograd.Function):
168
+
169
+ @staticmethod
170
+ @input_guard
171
+ def forward(
172
+ ctx,
173
+ q: torch.Tensor,
174
+ k: torch.Tensor,
175
+ v: torch.Tensor,
176
+ g: torch.Tensor,
177
+ beta: torch.Tensor,
178
+ scale: float,
179
+ initial_state: torch.Tensor,
180
+ output_final_state: bool,
181
+ offsets: Optional[torch.LongTensor] = None,
182
+ use_qk_l2norm_in_kernel: bool = False
183
+ ):
184
+ o, final_state = fused_recurrent_gated_delta_rule_fwd(
185
+ q=q,
186
+ k=k,
187
+ v=v,
188
+ g=g,
189
+ beta=beta,
190
+ scale=scale,
191
+ initial_state=initial_state,
192
+ output_final_state=output_final_state,
193
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
194
+ offsets=offsets
195
+ )
196
+
197
+ return o, final_state
198
+
199
+ @staticmethod
200
+ @input_guard
201
+ def backward(ctx, do, dht):
202
+ raise NotImplementedError(
203
+ "Backward pass is not implemented yet and we do not have plans to implement it "
204
+ "because we haven't figured out how to compute dg without materializing the full "
205
+ "hidden states for all time steps."
206
+ )
207
+
208
+
209
+ def fused_recurrent_gated_delta_rule(
210
+ q: torch.Tensor,
211
+ k: torch.Tensor,
212
+ v: torch.Tensor,
213
+ g: torch.Tensor,
214
+ beta: torch.Tensor = None,
215
+ scale: float = None,
216
+ initial_state: torch.Tensor = None,
217
+ output_final_state: bool = False,
218
+ cu_seqlens: Optional[torch.LongTensor] = None,
219
+ use_qk_l2norm_in_kernel: bool = False,
220
+ head_first: bool = False,
221
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
222
+ r"""
223
+ Args:
224
+ q (torch.Tensor):
225
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
226
+ k (torch.Tensor):
227
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
228
+ v (torch.Tensor):
229
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
230
+ g (torch.Tensor):
231
+ g (decays) of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`.
232
+ beta (torch.Tensor):
233
+ betas of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`.
234
+ scale (Optional[int]):
235
+ Scale factor for the RetNet attention scores.
236
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
237
+ initial_state (Optional[torch.Tensor]):
238
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
239
+ For equal-length input sequences, `N` equals the batch size `B`.
240
+ Default: `None`.
241
+ output_final_state (Optional[bool]):
242
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
243
+ cu_seqlens (torch.LongTensor):
244
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
245
+ consistent with the FlashAttention API.
246
+
247
+ Returns:
248
+ o (torch.Tensor):
249
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
250
+ final_state (torch.Tensor):
251
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
252
+
253
+ Examples::
254
+ >>> import torch
255
+ >>> import torch.nn.functional as F
256
+ >>> from einops import rearrange
257
+ >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
258
+ # inputs with equal lengths
259
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
260
+ >>> q = torch.randn(B, T, H, K, device='cuda')
261
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
262
+ >>> v = torch.randn(B, T, H, V, device='cuda')
263
+ >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda'))
264
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
265
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
266
+ >>> o, ht = fused_recurrent_gated_delta_rule(
267
+ q, k, v, g, beta,
268
+ initial_state=h0,
269
+ output_final_state=True,
270
+ )
271
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
272
+ >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
273
+ # for a batch with 4 sequences, a `cu_seqlens` tensor with 5 start/end positions is expected
274
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
275
+ >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
276
+ q, k, v, g, beta,
277
+ initial_state=h0,
278
+ output_final_state=True,
279
+ cu_seqlens=cu_seqlens
280
+ )
281
+ >>> assert o.allclose(o_var.view(o.shape))
282
+ >>> assert ht.allclose(ht_var)
283
+ """
284
+ if cu_seqlens is not None:
285
+ if q.shape[0] != 1:
286
+ raise ValueError(
287
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
288
+ f"Please flatten variable-length inputs before processing."
289
+ )
290
+ if head_first:
291
+ raise RuntimeError(
292
+ "Sequences with variable lengths are not supported for head-first mode"
293
+ )
294
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
295
+ raise ValueError(
296
+ f"The number of initial states is expected to be equal to the number of input sequences, "
297
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
298
+ )
299
+ if scale is None:
300
+ scale = k.shape[-1] ** -0.5
301
+ else:
302
+ assert scale > 0, "scale must be positive"
303
+ if beta is None:
304
+ beta = torch.ones_like(q[..., 0])
305
+ if head_first:
306
+ q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g, beta))
307
+ o, final_state = FusedRecurrentFunction.apply(
308
+ q,
309
+ k,
310
+ v,
311
+ g,
312
+ beta,
313
+ scale,
314
+ initial_state,
315
+ output_final_state,
316
+ cu_seqlens,
317
+ use_qk_l2norm_in_kernel
318
+ )
319
+ if head_first:
320
+ o = rearrange(o, 'b t h v -> b h t v')
321
+ return o, final_state
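
For readers of the kernel above, here is a minimal eager-mode sketch of the per-step recurrence it implements for a single (batch, head) slice with a scalar `beta` and no `cu_seqlens`; it mirrors the in-kernel math and is not the library API.

```python
import torch

def naive_gated_delta_rule(q, k, v, g, beta, scale=None):
    """Reference for one (batch, head) slice: q, k are [T, K]; v is [T, V]; g, beta are [T]."""
    T, K = k.shape
    scale = K ** -0.5 if scale is None else scale
    h = k.new_zeros(K, v.shape[-1])      # recurrent state, mirrors b_h in the kernel
    o = torch.empty_like(v)
    for t in range(T):
        h = h * g[t].exp()               # gated decay of the whole state
        u = beta[t] * (v[t] - k[t] @ h)  # delta-rule correction of the value
        h = h + torch.outer(k[t], u)     # rank-1 state update
        o[t] = (scale * q[t]) @ h
    return o
```

The optional `USE_QK_L2NORM_IN_KERNEL` branch additionally l2-normalizes `q` and `k` before this update, and `IS_BETA_HEADWISE` lets `beta` be a per-value vector instead of the scalar used here.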
fla/ops/gated_delta_rule/wy_fast.py ADDED
@@ -0,0 +1,620 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import safe_exp
11
+ from fla.utils import check_shared_mem
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ k,
28
+ g,
29
+ beta,
30
+ Aw,
31
+ Au,
32
+ offsets,
33
+ indices,
34
+ T,
35
+ H: tl.constexpr,
36
+ K: tl.constexpr,
37
+ BT: tl.constexpr,
38
+ BK: tl.constexpr,
39
+ BC: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ USE_OFFSETS: tl.constexpr
42
+ ):
43
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
44
+ i_b, i_h = i_bh // H, i_bh % H
45
+ if USE_OFFSETS:
46
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
47
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
48
+ T = eos - bos
49
+ else:
50
+ bos, eos = i_b * T, i_b * T + T
51
+
52
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
53
+ if HEAD_FIRST:
54
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
55
+ else:
56
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
57
+
58
+ b_beta = tl.load(p_beta, boundary_check=(0,))
59
+
60
+ for i_k in range(tl.cdiv(K, BK)):
61
+ if HEAD_FIRST:
62
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
63
+ else:
64
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
65
+ b_k = tl.load(p_k, boundary_check=(0, 1))
66
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
67
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
68
+
69
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
70
+
71
+ if HEAD_FIRST:
72
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
73
+ else:
74
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
75
+
76
+ b_g = tl.load(p_g, boundary_check=(0,))
77
+ b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :])
78
+
79
+ for i in range(1, BC):
80
+ mask = tl.arange(0, BC) == i
81
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
82
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
83
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
84
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
85
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
86
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
87
+
88
+ # blockwise computation of lower triangular matrix's inverse
89
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
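+ # the loop above already completes the inversion for this single BC x BC block via
+ # row-wise forward substitution; after adding the identity below, b_Aw holds
+ # (I + strict_lower(diag(beta) K K^T))^{-1} and b_Au its gate-weighted counterpart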
90
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
91
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
92
+ if HEAD_FIRST:
93
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
94
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
95
+ else:
96
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
97
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
98
+ tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1))
99
+ tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1))
100
+
101
+
102
+ @triton.heuristics({
103
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
104
+ })
105
+ @triton.autotune(
106
+ configs=[
107
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
108
+ for num_warps in [2, 4, 8]
109
+ for num_stages in [2, 3, 4]
110
+ ],
111
+ key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'],
112
+ )
113
+ @triton.jit(do_not_specialize=['T'])
114
+ def fwd_prepare_wy_repr_kernel_chunk64(
115
+ k,
116
+ g,
117
+ beta,
118
+ Aw,
119
+ Au,
120
+ offsets,
121
+ indices,
122
+ T,
123
+ H: tl.constexpr,
124
+ K: tl.constexpr,
125
+ BT: tl.constexpr,
126
+ BK: tl.constexpr,
127
+ BC: tl.constexpr,
128
+ USE_OFFSETS: tl.constexpr,
129
+ HEAD_FIRST: tl.constexpr
130
+ ):
131
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
132
+ i_b, i_h = i_bh // H, i_bh % H
133
+ if USE_OFFSETS:
134
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
135
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
136
+ T = eos - bos
137
+ else:
138
+ bos, eos = i_b * T, i_b * T + T
139
+
140
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
141
+ b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32)
142
+ b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32)
143
+ if HEAD_FIRST:
144
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
145
+ p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
146
+ else:
147
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
148
+ p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
149
+
150
+ b_beta = tl.load(p_beta, boundary_check=(0,))
151
+ b_beta2 = tl.load(p_beta2, boundary_check=(0,))
152
+
153
+ for i_k in range(tl.cdiv(K, BK)):
154
+ if HEAD_FIRST:
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
156
+ p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
157
+ else:
158
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
159
+ p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
162
+ b_k2 = tl.load(p_k2, boundary_check=(0, 1))
163
+ b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
164
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
165
+ b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2))
166
+ b_Aw3 += tl.dot(b_kb2, tl.trans(b_k))
167
+
168
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
169
+ b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0)
170
+
171
+ if HEAD_FIRST:
172
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
173
+ p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
174
+ else:
175
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
176
+ p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
177
+ b_g = tl.load(p_g, boundary_check=(0,))
178
+ b_g2 = tl.load(p_g2, boundary_check=(0,))
179
+
180
+ mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, :]
181
+ mask_g = i_t * BT + tl.arange(0, BC) < T
182
+ mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T
183
+
184
+ b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0)
185
+ b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * safe_exp(b_g2[:, None] - b_g2[None, :]), 0)
186
+ b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0)
187
+
188
+ for i in range(1, BC):
189
+ mask = tl.arange(0, BC) == i
190
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
191
+ b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0)
192
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
193
+ b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0)
194
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
195
+ b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i)
196
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
197
+ b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i)
198
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
199
+ b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2)
200
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
201
+ b_Au2 = tl.where(mask[:, None], b_au2, b_Au2)
202
+ # blockwise computation of lower triangular matrix's inverse
203
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
204
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
205
+ b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
206
+ # improve precision by disallowing tf32.
207
+ b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False)
208
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
209
+ b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
210
+ b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False)
211
+
212
+ if HEAD_FIRST:
213
+ p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
214
+ p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
215
+ p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
216
+ p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
217
+ p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
218
+ p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
219
+ p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
220
+ p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
221
+ else:
222
+ p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
223
+ p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
224
+ p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
225
+ p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
226
+ p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
227
+ p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
228
+ p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
229
+ p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
230
+
231
+ tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1))
232
+ tl.store(p_Aw2, b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1))
233
+ tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1))
234
+ tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1))
235
+ tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1))
236
+ tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1))
237
+ tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1))
238
+ tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1))
239
+
240
+
241
+ @triton.heuristics({
242
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
243
+ })
244
+ @triton.autotune(
245
+ configs=[
246
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
247
+ for num_warps in [2, 4, 8]
248
+ for num_stages in [2, 3, 4]
249
+ ],
250
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
251
+ )
252
+ @triton.jit(do_not_specialize=['T'])
253
+ def fwd_recompute_w_u_kernel(
254
+ k,
255
+ v,
256
+ beta,
257
+ w,
258
+ u,
259
+ Aw,
260
+ Au,
261
+ offsets,
262
+ indices,
263
+ T,
264
+ H: tl.constexpr,
265
+ K: tl.constexpr,
266
+ V: tl.constexpr,
267
+ BT: tl.constexpr,
268
+ BK: tl.constexpr,
269
+ BV: tl.constexpr,
270
+ HEAD_FIRST: tl.constexpr,
271
+ USE_OFFSETS: tl.constexpr
272
+ ):
273
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
274
+ i_b, i_h = i_bh // H, i_bh % H
275
+ if USE_OFFSETS:
276
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
277
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
278
+ T = eos - bos
279
+ else:
280
+ bos, eos = i_b * T, i_b * T + T
281
+ if HEAD_FIRST:
282
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
283
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
284
+ else:
285
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
286
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
287
+ b_beta = tl.load(p_beta, boundary_check=(0,))
288
+ b_Au = tl.load(p_Au, boundary_check=(0, 1))
289
+
290
+ for i_v in range(tl.cdiv(V, BV)):
291
+ if HEAD_FIRST:
292
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
293
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
294
+ else:
295
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
296
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
297
+ b_v = tl.load(p_v, boundary_check=(0, 1))
298
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
299
+ b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
300
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
301
+
302
+ tl.debug_barrier()
303
+ b_Au = None
304
+ if HEAD_FIRST:
305
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
306
+ else:
307
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
308
+ b_Aw = tl.load(p_Aw, boundary_check=(0, 1))
309
+
310
+ for i_k in range(tl.cdiv(K, BK)):
311
+ if HEAD_FIRST:
312
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
313
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
314
+ else:
315
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
316
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
317
+ b_k = tl.load(p_k, boundary_check=(0, 1))
318
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
319
+ b_w = tl.dot(b_Aw, b_kb)
320
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
321
+
322
+
323
+ def fwd_prepare_wy_repr(
324
+ k: torch.Tensor,
325
+ v: torch.Tensor,
326
+ g: torch.Tensor,
327
+ beta: torch.Tensor,
328
+ offsets: Optional[torch.LongTensor],
329
+ indices: Optional[torch.LongTensor],
330
+ head_first: bool = True,
331
+ chunk_size: int = 64
332
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
333
+ if head_first:
334
+ B, H, T, K = k.shape
335
+ else:
336
+ B, T, H, K = k.shape
337
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
338
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
339
+ BC = min(BT, 32)
340
+ BK = min(triton.next_power_of_2(K), 64)
341
+ # bf16 should be good enough.
342
+ Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
343
+ Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
344
+
345
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
346
+ fwd_fn[(NT, B*H)](
347
+ k=k,
348
+ g=g,
349
+ beta=beta,
350
+ Aw=Aw,
351
+ Au=Au,
352
+ offsets=offsets,
353
+ indices=indices,
354
+ T=T,
355
+ H=H,
356
+ K=K,
357
+ BT=BT,
358
+ BK=BK,
359
+ BC=BC,
360
+ HEAD_FIRST=head_first
361
+ )
362
+ w, u = fwd_recompute_w_u(
363
+ k=k,
364
+ v=v,
365
+ beta=beta,
366
+ Aw=Aw,
367
+ Au=Au,
368
+ offsets=offsets,
369
+ indices=indices,
370
+ head_first=head_first,
371
+ chunk_size=chunk_size
372
+ )
373
+ return w, u, Aw, Au
374
+
375
+
376
+ def fwd_recompute_w_u(
377
+ k: torch.Tensor,
378
+ v: torch.Tensor,
379
+ beta: torch.Tensor,
380
+ Aw: torch.Tensor,
381
+ Au: torch.Tensor,
382
+ offsets: Optional[torch.LongTensor],
383
+ indices: Optional[torch.LongTensor],
384
+ head_first: bool,
385
+ chunk_size: int
386
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
387
+ if head_first:
388
+ B, H, T, K, V = *k.shape, v.shape[-1]
389
+ else:
390
+ B, T, H, K, V = *k.shape, v.shape[-1]
391
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
392
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
393
+ BK = min(triton.next_power_of_2(K), 64)
394
+ BV = min(triton.next_power_of_2(V), 64)
395
+
396
+ u = torch.empty_like(v)
397
+ w = torch.empty_like(k)
398
+ fwd_recompute_w_u_kernel[(NT, B*H)](
399
+ k=k,
400
+ v=v,
401
+ beta=beta,
402
+ w=w,
403
+ u=u,
404
+ Aw=Aw,
405
+ Au=Au,
406
+ offsets=offsets,
407
+ indices=indices,
408
+ T=T,
409
+ H=H,
410
+ K=K,
411
+ V=V,
412
+ BT=BT,
413
+ BK=BK,
414
+ BV=BV,
415
+ HEAD_FIRST=head_first
416
+ )
417
+ return w, u
418
+
419
+
420
+ @triton.heuristics({
421
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
422
+ })
423
+ @triton.autotune(
424
+ configs=[
425
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
426
+ for num_warps in [2, 4]
427
+ for num_stages in [2, 3, 4]
428
+ ],
429
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS']
430
+ )
431
+ @triton.jit(do_not_specialize=['T'])
432
+ def bwd_prepare_wy_repr_kernel(
433
+ k,
434
+ v,
435
+ beta,
436
+ g,
437
+ Aw,
438
+ Au,
439
+ dw,
440
+ du,
441
+ dk,
442
+ dv,
443
+ dbeta,
444
+ dg,
445
+ offsets,
446
+ indices,
447
+ T,
448
+ H: tl.constexpr,
449
+ K: tl.constexpr,
450
+ V: tl.constexpr,
451
+ BT: tl.constexpr,
452
+ BK: tl.constexpr,
453
+ BV: tl.constexpr,
454
+ HEAD_FIRST: tl.constexpr,
455
+ USE_OFFSETS: tl.constexpr
456
+ ):
457
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
458
+ i_b, i_h = i_bh // H, i_bh % H
459
+ if USE_OFFSETS:
460
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
461
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
462
+ T = eos - bos
463
+ else:
464
+ bos, eos = i_b * T, i_b * T + T
465
+
466
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
467
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
468
+ if HEAD_FIRST:
469
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
470
+ p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
471
+ else:
472
+ p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
473
+ p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
474
+
475
+ b_A = tl.load(p_A, boundary_check=(0, 1))
476
+ b_beta = tl.load(p_beta, boundary_check=(0,))
477
+
478
+ for i_k in range(tl.cdiv(K, BK)):
479
+ if HEAD_FIRST:
480
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
481
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
482
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
483
+ else:
484
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
485
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
486
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
487
+ b_k = tl.load(p_k, boundary_check=(0, 1))
488
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
489
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
490
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
491
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
492
+ b_dk = b_dk_beta * b_beta[:, None]
493
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
494
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
495
+
496
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
497
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
498
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
499
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
500
+
501
+ if HEAD_FIRST:
502
+ p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
503
+ else:
504
+ p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
505
+ b_A = tl.load(p_A, boundary_check=(0, 1))
506
+ b_dA2 = tl.zeros([BT, BT], dtype=tl.float32)
507
+
508
+ for i_v in range(tl.cdiv(V, BV)):
509
+ if HEAD_FIRST:
510
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
511
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
512
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
513
+ else:
514
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
515
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
516
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
517
+ b_v = tl.load(p_v, boundary_check=(0, 1))
518
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
519
+ b_du = tl.load(p_du, boundary_check=(0, 1))
520
+ b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
521
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
522
+ b_dv = b_dv_beta * b_beta[:, None]
523
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
524
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
525
+
526
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0)
527
+ b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A)
528
+ b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype))
529
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty)
530
+ if HEAD_FIRST:
531
+ p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
532
+ else:
533
+ p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
534
+ b_g = tl.load(p_g, boundary_check=(0,))
535
+ b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :])
536
+ b_dA += b_dA2
537
+ b_dA = b_dA.to(k.dtype.element_ty)
538
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
539
+
540
+ for i_k in range(tl.cdiv(K, BK)):
541
+ if HEAD_FIRST:
542
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
543
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
544
+ else:
545
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
546
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
547
+ b_k = tl.load(p_k, boundary_check=(0, 1))
548
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
549
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
550
+ b_A += tl.dot(b_k_beta, tl.trans(b_k))
551
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
552
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
553
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
554
+ b_dk += b_dk_beta * b_beta[:, None]
555
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
556
+ b_dA2 *= b_A
557
+ b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0)
558
+ if HEAD_FIRST:
559
+ p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
560
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
561
+ else:
562
+ p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
563
+ p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
564
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
565
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
566
+
567
+
568
+ def bwd_prepare_wy_repr(
569
+ k: torch.Tensor,
570
+ v: torch.Tensor,
571
+ g: torch.Tensor,
572
+ beta: torch.Tensor,
573
+ Aw: torch.Tensor,
574
+ Au: torch.Tensor,
575
+ dw: torch.Tensor,
576
+ du: torch.Tensor,
577
+ offsets: Optional[torch.LongTensor],
578
+ indices: Optional[torch.LongTensor],
579
+ head_first: bool,
580
+ chunk_size: int
581
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
582
+ if head_first:
583
+ B, H, T, K, V = *k.shape, v.shape[-1]
584
+ else:
585
+ B, T, H, K, V = *k.shape, v.shape[-1]
586
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
587
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
588
+ CONST_TILING = 64 if check_shared_mem() else 32
589
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
590
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
591
+
592
+ dk = torch.empty_like(k)
593
+ dv = torch.empty_like(v)
594
+ dbeta = torch.empty_like(beta)
595
+ dg = torch.empty_like(g)
596
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
597
+ k=k,
598
+ v=v,
599
+ beta=beta,
600
+ g=g,
601
+ Aw=Aw,
602
+ Au=Au,
603
+ dw=dw,
604
+ du=du,
605
+ dk=dk,
606
+ dv=dv,
607
+ dbeta=dbeta,
608
+ dg=dg,
609
+ offsets=offsets,
610
+ indices=indices,
611
+ T=T,
612
+ H=H,
613
+ K=K,
614
+ V=V,
615
+ BT=BT,
616
+ BK=BK,
617
+ BV=BV,
618
+ HEAD_FIRST=head_first
619
+ )
620
+ return dk, dv, dbeta, dg
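
Usage note (not part of the diff): fwd_prepare_wy_repr is the forward entry point of this file; it launches the chunk-32/64 kernels to build the triangular factors Aw/Au and then calls fwd_recompute_w_u to materialise w and u, while bwd_prepare_wy_repr mirrors it for gradients. Below is a minimal, hedged sketch of calling it on dense (non-varlen) inputs on a CUDA device; the module path, shapes, dtypes and the head_first=True layout are illustrative assumptions, not something stated in this excerpt.

import torch
# assumed module path; the functions above are defined at module level
from fla.ops.gated_delta_rule.wy_fast import fwd_prepare_wy_repr

B, H, T, K, V = 2, 4, 128, 64, 64                                    # illustrative sizes, head_first layout [B, H, T, *]
k    = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v    = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
g    = torch.randn(B, H, T,    device='cuda', dtype=torch.float32)   # per-step log decay
beta = torch.rand(B, H, T,     device='cuda', dtype=torch.bfloat16)  # per-step beta in (0, 1)

# w/u are the WY-transformed keys/values; Aw/Au are the per-chunk triangular factors reused in backward
w, u, Aw, Au = fwd_prepare_wy_repr(
    k=k, v=v, g=g, beta=beta,
    offsets=None, indices=None,   # dense path; varlen batches would pass offset/index tensors here
    head_first=True, chunk_size=64,
)
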
fla/ops/gla/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gla
4
+ from .fused_chunk import fused_chunk_gla
5
+ from .fused_recurrent import fused_recurrent_gla
6
+
7
+ __all__ = [
8
+ 'chunk_gla',
9
+ 'fused_chunk_gla',
10
+ 'fused_recurrent_gla'
11
+ ]
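
Orientation note (not part of the diff): this __init__.py only re-exports the three GLA kernel variants defined in this package. As a rough guide rather than anything stated here, chunk_gla and fused_chunk_gla are chunk-parallel paths typically used for training and prefill, while fused_recurrent_gla is the step-wise recurrence typically used for decoding; all three are imported the same way:

from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
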
fla/ops/gla/fused_chunk.py ADDED
@@ -0,0 +1,631 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+ from packaging import version
12
+
13
+ from fla.ops.utils import chunk_local_cumsum
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ @triton.jit(do_not_specialize=['T'])
19
+ def prepare_qg_kg(
20
+ q,
21
+ k,
22
+ g,
23
+ qg,
24
+ kg,
25
+ scale,
26
+ T,
27
+ K: tl.constexpr,
28
+ BT: tl.constexpr,
29
+ BK: tl.constexpr
30
+ ):
31
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+ p_q = q + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
33
+ p_g = g + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
34
+ p_k = k + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
35
+ p_qg = qg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
36
+ p_kg = kg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
37
+
38
+ mask = (i_k * BK + tl.arange(0, BK)) < K
39
+
40
+ last_decay = tl.load(g + i_bh * T*K + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK))
41
+
42
+ for _ in range(BT):
43
+ b_q = tl.load(p_q, mask=mask, other=0)
44
+ b_k = tl.load(p_k, mask=mask, other=0)
45
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
46
+ b_q *= exp(b_g) * scale
47
+ b_k *= exp(last_decay - b_g)
48
+ tl.store(p_kg, b_k.to(p_kg.dtype.element_ty), mask=mask)
49
+ tl.store(p_qg, b_q.to(p_qg.dtype.element_ty), mask=mask)
50
+ p_q += K
51
+ p_g += K
52
+ p_k += K
53
+ p_kg += K
54
+ p_qg += K
55
+
56
+
57
+ @triton.jit(do_not_specialize=['T'])
58
+ def bwd_decay_global_cumsum(
59
+ dq_inner,
60
+ dq_inter,
61
+ dk_inner,
62
+ dk_inter,
63
+ q,
64
+ k,
65
+ g,
66
+ dg,
67
+ T,
68
+ K: tl.constexpr,
69
+ BT: tl.constexpr,
70
+ BK: tl.constexpr
71
+ ):
72
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
73
+ p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
74
+ p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
75
+ p_g = g + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
76
+ p_dg = dg + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
77
+ p_dq_inner = dq_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
78
+ p_dk_inner = dk_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
79
+ p_dq_inter = dq_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
80
+ p_dk_inter = dk_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
81
+ cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
82
+ mask = (i_k * BK + tl.arange(0, BK)) < K
83
+ last_g = tl.zeros([BK], dtype=tl.float32)
84
+ for j in range(BT-1, -1, -1):
85
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
86
+ if j == (BT-1):
87
+ last_g = b_g
88
+ b_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
89
+ b_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
90
+ b_dq2 *= exp(b_g)
91
+ b_dq = b_dq1 + b_dq2
92
+ tl.store(p_dq_inter, b_dq, mask=mask)
93
+ b_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
94
+ b_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
95
+ b_dk2 *= safe_exp(last_g - b_g)
96
+ b_dk = b_dk1 + b_dk2
97
+ tl.store(p_dk_inter, b_dk, mask=mask)
98
+ b_q = tl.load(p_q, mask=mask, other=0)
99
+ b_k = tl.load(p_k, mask=mask, other=0)
100
+ b_dg = b_dq * b_q - b_dk * b_k
101
+ cum_grad_dg += b_dg
102
+ tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
103
+ p_g -= K
104
+ p_k -= K
105
+ p_q -= K
106
+ p_dq_inner -= K
107
+ p_dk_inner -= K
108
+ p_dq_inter -= K
109
+ p_dk_inter -= K
110
+ p_dg -= K
111
+
112
+
113
+ @triton.jit(do_not_specialize=['T'])
114
+ def fused_chunk_gla_fwd_kernel(
115
+ q,
116
+ k,
117
+ v,
118
+ g,
119
+ o,
120
+ h0,
121
+ ht,
122
+ T,
123
+ B: tl.constexpr,
124
+ H: tl.constexpr,
125
+ K: tl.constexpr,
126
+ V: tl.constexpr,
127
+ BT: tl.constexpr,
128
+ BK: tl.constexpr,
129
+ BV: tl.constexpr,
130
+ USE_INITIAL_STATE: tl.constexpr,
131
+ STORE_FINAL_STATE: tl.constexpr,
132
+ CHECK: tl.constexpr
133
+ ):
134
+ # indices
135
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
136
+
137
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
138
+
139
+ # make block pointers
140
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
141
+ p_gn = g + i_bh * T*K + (BT - 1) * K + i_k * BK + tl.arange(0, BK)
142
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
143
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
144
+ p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
145
+
146
+ if USE_INITIAL_STATE:
147
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
148
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
149
+
150
+ mask = (i_k * BK + tl.arange(0, BK)) < K
151
+
152
+ for i in range(0, tl.cdiv(T, BT)):
153
+ # [BK, BT]
154
+ b_k = tl.load(p_k, boundary_check=(0, 1))
155
+ # [BT, BV]
156
+ b_v = tl.load(p_v, boundary_check=(0, 1))
157
+ # [BT, BK]
158
+ b_q = tl.load(p_q, boundary_check=(0, 1))
159
+ b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
160
+ if CHECK and i == 0:
161
+ b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
162
+ b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
163
+ else:
164
+ b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
165
+ b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
166
+
167
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
168
+ p_q = tl.advance(p_q, (BT, 0))
169
+ p_k = tl.advance(p_k, (0, BT))
170
+ p_v = tl.advance(p_v, (BT, 0))
171
+ p_o = tl.advance(p_o, (BT, 0))
172
+ p_gn += BT * K
173
+
174
+ if STORE_FINAL_STATE:
175
+ p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
176
+ tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
177
+
178
+
179
+ # Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
180
+ @triton.jit(do_not_specialize=['T'])
181
+ def fused_chunk_gla_bwd_kernel(
182
+ q, k, v, g,
183
+ do,
184
+ dq,
185
+ dk,
186
+ dv,
187
+ h0,
188
+ scale,
189
+ T,
190
+ B: tl.constexpr,
191
+ H: tl.constexpr,
192
+ K: tl.constexpr,
193
+ V: tl.constexpr,
194
+ # clamp_min, # minimum log value of the gate for numerical stability. default: -5
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_INITIAL_STATE: tl.constexpr,
199
+ CHECK: tl.constexpr
200
+ ):
201
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ # [BV, BK]
203
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
204
+
205
+ if USE_INITIAL_STATE:
206
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
207
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
208
+
209
+ mask = (i_k * BK + tl.arange(0, BK)) < K
210
+ for i in range(0, tl.cdiv(T, BT)):
211
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
212
+ p_gn = g + i_bh * T*K + ((i+1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
213
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
214
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
215
+ p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
216
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
217
+ # [BT, K]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
220
+
221
+ # [V, BT]
222
+ b_v = tl.load(p_v, boundary_check=(0, 1))
223
+ # [BT, V]
224
+ b_do = tl.load(p_do, boundary_check=(0, 1))
225
+ # [V, K]
226
+ if CHECK and i == 0:
227
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
228
+ b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
229
+ else:
230
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
231
+ b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
232
+ b_dq *= scale
233
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
234
+
235
+ # sync threads
236
+ b_h = None
237
+ tl.debug_barrier()
238
+ # [BK, BV]
239
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
240
+
241
+ # cum = tl.zeros([BK], dtype=tl.float32)
242
+ for i in range(1, tl.cdiv(T, BT) + 1):
243
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
244
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
245
+ p_gn = g + i_bh * T*K + (T - (i-1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
246
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
247
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
248
+ p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * T*K, (T, K),
249
+ (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * T*V, (T, V),
251
+ (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
252
+ # [K, BT]
253
+ b_q = tl.load(p_q, boundary_check=(0, 1))
254
+ # [BT, K]
255
+ b_k = tl.load(p_k, boundary_check=(0, 1))
256
+ # [BT, V]
257
+ b_v = tl.load(p_v, boundary_check=(0, 1))
258
+ b_do = tl.load(p_do, boundary_check=(0, 1))
259
+ b_db = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
260
+
261
+ # inter-chunk
262
+ # [K, V]
263
+ if CHECK and i == 1:
264
+ b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
265
+ b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
266
+ b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
267
+ else:
268
+ b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
269
+ b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
270
+ b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
271
+
272
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
273
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
274
+
275
+
276
+ @triton.jit
277
+ def fwd_inner_chunk(
278
+ q, k, g, A,
279
+ scale, # K ** -0.5
280
+ B: tl.constexpr, # B
281
+ H: tl.constexpr, # H
282
+ T, # T
283
+ K: tl.constexpr, # K
284
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
285
+ BK: tl.constexpr # BLOCK SIZE along the K dimension
286
+ ):
287
+
288
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
289
+
290
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
291
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
292
+
293
+ b_k = tl.load(p_k, boundary_check=(0, 1))
294
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
295
+
296
+ mask = (i_k * BK + tl.arange(0, BK)) < K
297
+ o_i = tl.arange(0, BT)
298
+
299
+ p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
300
+ p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
301
+ p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
302
+
303
+ for i in range(BT):
304
+ b_q = tl.load(p_q, mask=mask, other=0) * scale
305
+ b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
306
+ s = b_q[None, :] * b_k * safe_exp(b_gq[None, :] - b_g)
307
+ score = tl.sum(s, axis=1)
308
+ score = tl.where(o_i <= i, score, 0)
309
+ tl.store(p_A, score.to(p_A.dtype.element_ty))
310
+ p_q += K
311
+ p_gq += K
312
+ p_A += BT
313
+
314
+
315
+ @triton.jit
316
+ def bwd_inner_chunk(
317
+ q,
318
+ k,
319
+ g,
320
+ dA,
321
+ dq,
322
+ dk,
323
+ T, # T
324
+ K: tl.constexpr, # K
325
+ # clamp_min, # minimum log value of the gate for numerical stability. default: -5
326
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
327
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
328
+ ):
329
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
330
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
331
+ b_k = tl.load(p_k, boundary_check=(0, 1))
332
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
333
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
334
+
335
+ mask = (i_k * BK + tl.arange(0, BK)) < K
336
+ o_i = tl.arange(0, BT)
337
+
338
+ p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
339
+ p_dq = dq + (i_bh) * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
340
+ p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
341
+ p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
342
+
343
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
344
+
345
+ for i in range(BT):
346
+ b_q = tl.load(p_q, mask=mask, other=0)
347
+ b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
348
+ score = safe_exp(b_gq[None, :] - b_g)
349
+ score = tl.where(o_i[:, None] <= i, score, 0)
350
+ b_dA = tl.load(p_dA)
351
+ b_dA = tl.where(o_i <= i, b_dA, 0)
352
+ b_dk += (b_dA[:, None] * score * b_q[None, :])
353
+ b_dq = tl.sum(b_dA[:, None] * score * b_k, axis=0)
354
+ tl.store(p_dq, b_dq, mask=mask)
355
+ p_q += K
356
+ p_dq += K
357
+ p_gq += K
358
+ p_dA += BT
359
+
360
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
361
+ tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
362
+
363
+
364
+ class FusedChunkGLAFunction(torch.autograd.Function):
365
+
366
+ @staticmethod
367
+ @input_guard
368
+ @autocast_custom_fwd
369
+ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
370
+ ctx.g_dtype = g.dtype
371
+ ctx.scale = scale
372
+ B, H, T, K, V = *k.shape, v.shape[-1]
373
+ BT = 16 # chunk_size
374
+ BK, BV = min(K, 64), min(V, 64)
375
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
376
+ num_stages = 1
377
+ num_warps = 2
378
+
379
+ g_org = g
380
+ # the cumulative decay should be kept in float32, otherwise the error accumulates and gets amplified.
381
+ g = chunk_local_cumsum(g_org, chunk_size=BT)
382
+ o = q.new_empty(NK, B, H, T, V)
383
+ q_g = torch.empty_like(q)
384
+ k_g = torch.empty_like(k)
385
+
386
+ grid = (NK, triton.cdiv(T, BT), B * H)
387
+ prepare_qg_kg[grid](
388
+ q,
389
+ k,
390
+ g,
391
+ q_g,
392
+ k_g,
393
+ scale,
394
+ T=T,
395
+ K=K,
396
+ BT=BT,
397
+ BK=BK,
398
+ num_warps=1
399
+ )
400
+
401
+ if output_final_state:
402
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
403
+ else:
404
+ final_state = None
405
+ # the bug still exists even for Triton 2.2 on H100 GPUs
406
+ # so we always enable initial checks
407
+ CHECK = True
408
+ if version.parse(triton.__version__) < version.parse('2.2.0'):
409
+ import warnings
410
+ warnings.warn(
412
+ "Triton<2.2.0 detected when running this kernel, "
413
+ "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
414
+ "that lead to significant precision loss. "
415
+ "We've added some initial condition checks to work around this, at the cost of some speed. "
416
+ "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
416
+ )
417
+ CHECK = True
418
+
419
+ grid = (NV, NK, B * H)
420
+ fused_chunk_gla_fwd_kernel[grid](
421
+ q_g, k_g, v, g, o, initial_state, final_state,
422
+ T=T,
423
+ B=B,
424
+ H=H,
425
+ K=K,
426
+ V=V,
427
+ BT=BT,
428
+ BK=BK,
429
+ BV=BV,
430
+ USE_INITIAL_STATE=initial_state is not None,
431
+ STORE_FINAL_STATE=output_final_state,
432
+ CHECK=CHECK,
433
+ num_warps=num_warps,
434
+ num_stages=num_stages
435
+ )
436
+
437
+ o = o.sum(0)
438
+
439
+ # intra-chunk
440
+ chunk_size = 16
441
+ num_chunk = T // chunk_size
442
+ v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
443
+ BK = min(K, 64)
444
+ NK = triton.cdiv(K, BK)
445
+ A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT)
446
+ grid = (NK, triton.cdiv(T, BT), B * H)
447
+ fwd_inner_chunk[grid](
448
+ q, k, g, A,
449
+ scale,
450
+ B=B,
451
+ H=H,
452
+ T=T,
453
+ K=K,
454
+ BT=BT,
455
+ BK=BK,
456
+ num_stages=3,
457
+ num_warps=4
458
+ )
459
+ A = A.sum(0)
460
+ o2 = A @ v2
461
+ o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
462
+ # combine inner and inter
463
+ o.add_(o2)
464
+ ctx.save_for_backward(q, k, v, g_org, A, initial_state)
465
+ ctx.CHECK = CHECK
466
+ return o.to(v), final_state
467
+
468
+ @staticmethod
469
+ @input_guard
470
+ @autocast_custom_bwd
471
+ def backward(ctx, do, dht=None):
472
+ q, k, v, g_org, A, initial_state = ctx.saved_tensors
473
+ B, H, T, K, V = *k.shape, v.shape[-1]
474
+ scale = ctx.scale
475
+
476
+ # recomputation
477
+ # inter-chunk
478
+ BT = 16 # chunk_size
479
+ g = chunk_local_cumsum(g_org, chunk_size=BT)
480
+ BK, BV = min(K, 64), min(V, 64)
481
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
482
+ q_g = torch.empty_like(q)
483
+ k_g = torch.empty_like(k)
484
+ grid = (NK, triton.cdiv(T, BT), B * H)
485
+ prepare_qg_kg[grid](
486
+ q,
487
+ k,
488
+ g,
489
+ q_g,
490
+ k_g,
491
+ scale,
492
+ T=T,
493
+ K=K,
494
+ BT=BT,
495
+ BK=BK,
496
+ num_warps=1
497
+ )
498
+
499
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
500
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
501
+ num_stages = 1
502
+ num_warps = 2
503
+ dq = q.new_empty(NV, B, H, T, K)
504
+ dk = q.new_empty(NV, B, H, T, K)
505
+ dv = q.new_empty(NK, B, H, T, V)
506
+
507
+ grid = (NV, NK, B * H)
508
+
509
+ fused_chunk_gla_bwd_kernel[grid](
510
+ q_g,
511
+ k_g,
512
+ v,
513
+ g,
514
+ do,
515
+ dq,
516
+ dk,
517
+ dv,
518
+ initial_state,
519
+ scale,
520
+ T=T,
521
+ B=B,
522
+ H=H,
523
+ K=K,
524
+ V=V,
525
+ BT=BT,
526
+ BK=BK,
527
+ BV=BV,
528
+ USE_INITIAL_STATE=initial_state is not None,
529
+ CHECK=ctx.CHECK,
530
+ num_warps=num_warps,
531
+ num_stages=num_stages,
532
+ )
533
+ dq = dq.sum(0)
534
+ dk = dk.sum(0)
535
+ dv = dv.sum(0)
536
+
537
+ # intra chunk
538
+ NT = T // BT
539
+ v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT)
540
+ do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=NT)
541
+ dA2 = (do2 @ v2.transpose(-2, -1)) * scale
542
+ dv2 = A.transpose(-1, -2) @ do2
543
+ dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=NT)
544
+
545
+ BK = min(triton.next_power_of_2(K), 16)
546
+ NK = triton.cdiv(K, BK)
547
+ dk2 = torch.empty_like(k)
548
+ dq2 = torch.empty_like(q)
549
+
550
+ grid = (NK, NT, B * H)
551
+ bwd_inner_chunk[grid](
552
+ q, k, g,
553
+ dA2,
554
+ dq2,
555
+ dk2,
556
+ T=T,
557
+ K=K,
558
+ BT=BT,
559
+ BK=BK,
560
+ num_warps=1,
561
+ num_stages=3
562
+ )
563
+
564
+ BK = min(triton.next_power_of_2(K), 32)
565
+ NK = triton.cdiv(K, BK)
566
+ dg = torch.empty_like(g, dtype=torch.float32)
567
+ grid = (NK, triton.cdiv(T, BT), B * H)
568
+ bwd_decay_global_cumsum[grid](
569
+ dq2,
570
+ dq,
571
+ dk2,
572
+ dk,
573
+ q,
574
+ k,
575
+ g,
576
+ dg,
577
+ T=T,
578
+ K=K,
579
+ BT=BT,
580
+ BK=BK,
581
+ num_warps=1,
582
+ num_stages=1
583
+ )
584
+ dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
585
+
586
+ def rev_cumsum_exclusive(x):
587
+ cumsum_x = x.cumsum(-2)
588
+ rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
589
+ return rev_cumsum_x
590
+
591
+ rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
592
+ dg.add_(rev_cumsum_dg.unsqueeze(-2))
593
+ dv.add_(dv2)
594
+ dg = rearrange(dg, 'b h n c d -> b h (n c) d')
595
+
596
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
597
+
598
+
599
+ def ceildiv(a, b):
600
+ return -(a // -b)
601
+
602
+
603
+ def pad(x, chunk_size=16):
604
+ T = x.shape[-2]
605
+ padded_seq_len = ceildiv(T, chunk_size) * chunk_size
606
+ if x.shape[-2] % chunk_size != 0:
607
+ x = F.pad(x, (0, 0, 0, padded_seq_len - T))
608
+ return x
609
+
610
+
611
+ def fused_chunk_gla(
612
+ q: torch.Tensor,
613
+ k: torch.Tensor,
614
+ v: torch.Tensor,
615
+ g: torch.Tensor,
616
+ scale: float = -1,
617
+ initial_state: torch.Tensor = None,
618
+ output_final_state: bool = False,
619
+ head_first: bool = True
620
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
621
+ if scale == -1:
622
+ scale = q.shape[-1] ** -0.5
623
+ if not head_first:
624
+ q, k, v, g = map(lambda x: x.transpose(1, 2), (q, k, v, g))
625
+ seq_len = q.shape[-2]
626
+ q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
627
+ o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
628
+ o = o[..., :seq_len, :].contiguous()
629
+ if not head_first:
630
+ o = o.transpose(1, 2)
631
+ return o, final_state
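
To make the wrapper's call convention concrete, here is a minimal usage sketch of fused_chunk_gla (not part of the diff); the shapes, dtypes and the way the gates are constructed are illustrative assumptions.

import torch
import torch.nn.functional as F
from fla.ops.gla import fused_chunk_gla   # re-exported by fla/ops/gla/__init__.py above

B, H, T, K, V = 2, 4, 1024, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# g holds per-key log forget gates; log-sigmoid keeps them <= 0, i.e. decays in (0, 1]
g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda')).to(torch.bfloat16)

# scale=-1 falls back to K ** -0.5 inside the wrapper; head_first=True matches the [B, H, T, *] layout
o, final_state = fused_chunk_gla(q, k, v, g, scale=-1, output_final_state=True, head_first=True)
# o: [B, H, T, V]; final_state: [B, H, K, V], kept in float32
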
fla/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1264 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import reduce
10
+
11
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
12
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
13
+ from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import input_guard
16
+
17
+
18
+ @triton.heuristics({
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in [32, 64]
25
+ for BV in [32, 64]
26
+ for num_warps in [2, 4, 8]
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['BT']
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_gsa_fwd_k_kernel_inter(
33
+ q,
34
+ k,
35
+ h,
36
+ g,
37
+ o,
38
+ A,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ HQ: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ NG: tl.constexpr,
51
+ USE_OFFSETS: tl.constexpr,
52
+ HEAD_FIRST: tl.constexpr
53
+ ):
54
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
55
+ i_bg = i_bh // NG
56
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
57
+ i_h = i_hq // NG
58
+ if USE_OFFSETS:
59
+ i_tg = i_t
60
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
61
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
62
+ T = eos - bos
63
+ NT = tl.cdiv(T, BT)
64
+ else:
65
+ NT = tl.cdiv(T, BT)
66
+ i_tg = i_b * NT + i_t
67
+ bos, eos = i_b * T, i_b * T + T
68
+
69
+ o_i = tl.arange(0, BT)
70
+ m_s = o_i[:, None] >= o_i[None, :]
71
+
72
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
73
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
74
+ for i_k in range(tl.cdiv(K, BK)):
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
77
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
78
+ p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
79
+ else:
80
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
82
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
83
+
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ b_q = (b_q * scale).to(b_q.dtype)
87
+ # [BK, BT]
88
+ b_k = tl.load(p_k, boundary_check=(0, 1))
89
+ # [BK, BV]
90
+ b_h = tl.load(p_h, boundary_check=(0, 1))
91
+ # [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+ if HEAD_FIRST:
96
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
97
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
98
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
99
+ else:
100
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
101
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
102
+ p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
103
+ # [BT, BV]
104
+ b_g = tl.load(p_g, boundary_check=(0, 1))
105
+ b_o = b_o * exp(b_g)
106
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
107
+
108
+ # [BT, BT]
109
+ b_A = tl.where(m_s, b_A, 0.)
110
+ if i_v == 0:
111
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
116
+ })
117
+ @triton.jit(do_not_specialize=['T'])
118
+ def chunk_gsa_fwd_k_kernel_intra(
119
+ v,
120
+ g,
121
+ o,
122
+ A,
123
+ offsets,
124
+ indices,
125
+ T,
126
+ HQ: tl.constexpr,
127
+ H: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BC: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ NC: tl.constexpr,
133
+ NG: tl.constexpr,
134
+ USE_OFFSETS: tl.constexpr,
135
+ HEAD_FIRST: tl.constexpr
136
+ ):
137
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
138
+ i_bg = i_bh // NG
139
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
140
+ i_h = i_hq // NG
141
+ i_t, i_i = i_c // NC, i_c % NC
142
+ if USE_OFFSETS:
143
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
144
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
145
+ T = eos - bos
146
+ else:
147
+ bos, eos = i_b * T, i_b * T + T
148
+
149
+ o_v = i_v * BV + tl.arange(0, BV)
150
+ m_v = o_v < V
151
+
152
+ if i_t * BT + i_i * BC > T:
153
+ return
154
+
155
+ if HEAD_FIRST:
156
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
157
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV)
158
+ else:
159
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
160
+ p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
161
+ # [BV,]
162
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
163
+ # [BC, BV]
164
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
165
+ for i_j in range(0, i_i):
166
+ if HEAD_FIRST:
167
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
168
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
169
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
170
+ else:
171
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
172
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
173
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
174
+ # [BC, BV]
175
+ b_v = tl.load(p_v, boundary_check=(0, 1))
176
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
177
+ b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
178
+ # [BC, BC]
179
+ b_A = tl.load(p_A, boundary_check=(0, 1))
180
+ b_o += tl.dot(b_A, b_vg)
181
+ # [BC, BV]
182
+ b_g = tl.load(p_g, boundary_check=(0, 1))
183
+ b_o *= exp(b_g - b_gn[None, :])
184
+
185
+ o_i = tl.arange(0, BC)
186
+ if HEAD_FIRST:
187
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
188
+ else:
189
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
190
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
191
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
192
+ if HEAD_FIRST:
193
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
194
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
195
+ else:
196
+ p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
197
+ p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
198
+ # [BC,]
199
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
200
+ # [BV,]
201
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
202
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
203
+ # [BC, BV]
204
+ b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
205
+ # avoid 0 * inf = inf
206
+ b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
207
+ if HEAD_FIRST:
208
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
209
+ else:
210
+ p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
211
+ b_o += tl.load(p_o, boundary_check=(0, 1))
212
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
213
+
214
+
215
+ @triton.heuristics({
216
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
217
+ })
218
+ @triton.autotune(
219
+ configs=[
220
+ triton.Config({}, num_warps=num_warps)
221
+ for num_warps in [2, 4, 8]
222
+ ],
223
+ key=["BT"]
224
+ )
225
+ @triton.jit(do_not_specialize=['T'])
226
+ def chunk_gsa_bwd_k_kernel_dA(
227
+ v,
228
+ g,
229
+ do,
230
+ dA,
231
+ indices,
232
+ offsets,
233
+ scale,
234
+ T,
235
+ B: tl.constexpr,
236
+ HQ: tl.constexpr,
237
+ H: tl.constexpr,
238
+ V: tl.constexpr,
239
+ BT: tl.constexpr,
240
+ BC: tl.constexpr,
241
+ BV: tl.constexpr,
242
+ NC: tl.constexpr,
243
+ NG: tl.constexpr,
244
+ USE_OFFSETS: tl.constexpr,
245
+ HEAD_FIRST: tl.constexpr
246
+ ):
247
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
248
+ i_bg = i_bh // NG
249
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
250
+ i_h = i_hq // NG
251
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
252
+ if USE_OFFSETS:
253
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
254
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
255
+ all = T
256
+ T = eos - bos
257
+ else:
258
+ bos, eos = i_b * T, i_b * T + T
259
+ all = B * T
260
+
261
+ o_v = i_v * BV + tl.arange(0, BV)
262
+ m_v = o_v < V
263
+
264
+ if i_t * BT + i_i * BC > T:
265
+ return
266
+
267
+ if HEAD_FIRST:
268
+ p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
269
+ else:
270
+ p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))
271
+
272
+ # [BC, BC]
273
+ b_dA = tl.zeros([BC, BC], dtype=tl.float32)
274
+ if i_i > i_j:
275
+ if HEAD_FIRST:
276
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
277
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
278
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
279
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
280
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
281
+ else:
282
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
283
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
284
+ p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
285
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
286
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
287
+ # [BV,]
288
+ b_gn = tl.load(p_gn, mask=m_v, other=0.)
289
+ # [BC, BV]
290
+ b_g = tl.load(p_g, boundary_check=(0, 1))
291
+ b_do = tl.load(p_do, boundary_check=(0, 1))
292
+ b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
293
+ # [BV, BC]
294
+ b_v = tl.load(p_v, boundary_check=(0, 1))
295
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
296
+ b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
297
+ # [BC, BC]
298
+ b_dA = tl.dot(b_do, b_vg)
299
+ elif i_i == i_j:
300
+ if HEAD_FIRST:
301
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
302
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
303
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
304
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
305
+ else:
306
+ p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
307
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
308
+ p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
309
+ p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
310
+ # [BC, BV]
311
+ b_g = tl.load(p_g, boundary_check=(0, 1))
312
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
313
+ m_v = o_v < V
314
+
315
+ o_i = tl.arange(0, BC)
316
+ # [BC, BC]
317
+ m_dA = o_i[:, None] >= o_i[None, :]
318
+ for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
319
+ # [BV,]
320
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
321
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
322
+ # [BC,]
323
+ b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
324
+ b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)
325
+
326
+ p_v += (1 if HEAD_FIRST else H) * V
327
+ p_gv += (1 if HEAD_FIRST else H) * V
328
+ b_dA = tl.where(m_dA, b_dA, 0.)
329
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
330
+
331
+
332
+ @triton.heuristics({
333
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
334
+ })
335
+ @triton.autotune(
336
+ configs=[
337
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
338
+ for num_warps in [2, 4]
339
+ for num_stages in [2, 3, 4]
340
+ ],
341
+ key=['BT']
342
+ )
343
+ @triton.jit(do_not_specialize=['T'])
344
+ def chunk_gsa_bwd_k_kernel_dqkvg(
345
+ q,
346
+ k,
347
+ v,
348
+ h,
349
+ g,
350
+ A,
351
+ do,
352
+ dh,
353
+ dq,
354
+ dk,
355
+ dv,
356
+ dg,
357
+ dgv,
358
+ dA,
359
+ offsets,
360
+ indices,
361
+ scale,
362
+ T,
363
+ B: tl.constexpr,
364
+ HQ: tl.constexpr,
365
+ H: tl.constexpr,
366
+ K: tl.constexpr,
367
+ V: tl.constexpr,
368
+ BT: tl.constexpr,
369
+ BK: tl.constexpr,
370
+ BV: tl.constexpr,
371
+ NG: tl.constexpr,
372
+ USE_OFFSETS: tl.constexpr,
373
+ HEAD_FIRST: tl.constexpr
374
+ ):
375
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
376
+ i_bg = i_bh // NG
377
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
378
+ i_h = i_hq // NG
379
+ if USE_OFFSETS:
380
+ i_tg = i_t
381
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
382
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
383
+ all = T
384
+ T = eos - bos
385
+ NT = tl.cdiv(T, BT)
386
+ else:
387
+ NT = tl.cdiv(T, BT)
388
+ i_tg = i_b * NT + i_t
389
+ bos, eos = i_b * T, i_b * T + T
390
+ all = B * T
391
+
392
+ o_i = tl.arange(0, BT)
393
+ o_t = min(i_t * BT + BT, T)
394
+ m_s = o_i[:, None] >= o_i[None, :]
395
+
396
+ if HEAD_FIRST:
397
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
398
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
399
+ p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
400
+ else:
401
+ p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
402
+ p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
403
+ p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
404
+
405
+ # [BT, BK]
406
+ b_q = tl.load(p_q, boundary_check=(0, 1))
407
+ b_k = tl.load(p_k, boundary_check=(0, 1))
408
+ # [BT, BT]
409
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
410
+ b_A = tl.where(m_s, b_A, 0.)
411
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
412
+
413
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
414
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
415
+ for i_v in range(tl.cdiv(V, BV)):
416
+ o_v = i_v * BV + tl.arange(0, BV)
417
+ if HEAD_FIRST:
418
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
419
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
420
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV)
421
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
422
+ p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
423
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
424
+ p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
425
+ p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
426
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
427
+ else:
428
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
429
+ p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
430
+ p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
431
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
432
+ p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
433
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
434
+ p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
435
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
436
+ p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
437
+ m_v = o_v < V
438
+
439
+ # [BV,]
440
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
441
+ # [BT, BV]
442
+ b_v = tl.load(p_v, boundary_check=(0, 1))
443
+ b_g = tl.load(p_g, boundary_check=(0, 1))
444
+ b_gv = exp(b_gn[None, :] - b_g)
445
+ # [BV, BK]
446
+ b_h = tl.load(p_h, boundary_check=(0, 1))
447
+ # [BT, BV]
448
+ b_do = tl.load(p_do, boundary_check=(0, 1))
449
+ b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
450
+ # [BK, BV]
451
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
452
+ # [BV]
453
+ b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)
454
+
455
+ b_dh = b_dh.to(b_k.dtype)
456
+ # [BT, BK]
457
+ b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
458
+ b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
459
+ # [BT, BV]
460
+ b_dv = tl.dot(b_k, b_dh) * b_gv
461
+ # [BV]
462
+ b_dg += tl.sum(b_dv * b_v, 0)
463
+
464
+ if i_k == 0:
465
+ b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
466
+ else:
467
+ b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]
468
+
469
+ tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
470
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
471
+ if HEAD_FIRST:
472
+ p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
473
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
474
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
475
+ else:
476
+ p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
477
+ p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
478
+ p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
479
+ # [BT, BT]
480
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
481
+ # [BT, BK]
482
+ b_dq += tl.dot(b_dA, b_k)
483
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)
484
+
485
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
486
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
487
+
488
+
489
+ @triton.heuristics({
490
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
491
+ })
492
+ @triton.jit(do_not_specialize=['T'])
493
+ def chunk_gsa_bwd_k_kernel_intra_dvg(
494
+ v,
495
+ g,
496
+ o,
497
+ A,
498
+ do,
499
+ dv,
500
+ dg,
501
+ offsets,
502
+ indices,
503
+ T,
504
+ HQ: tl.constexpr,
505
+ H: tl.constexpr,
506
+ V: tl.constexpr,
507
+ BT: tl.constexpr,
508
+ BC: tl.constexpr,
509
+ BV: tl.constexpr,
510
+ NC: tl.constexpr,
511
+ NG: tl.constexpr,
512
+ USE_OFFSETS: tl.constexpr,
513
+ HEAD_FIRST: tl.constexpr
514
+ ):
515
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
516
+ i_bg = i_bh // NG
517
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
518
+ i_h = i_hq // NG
519
+ i_t, i_i = i_c // NC, i_c % NC
520
+ if USE_OFFSETS:
521
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
522
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
523
+ T = eos - bos
524
+ else:
525
+ bos, eos = i_b * T, i_b * T + T
526
+
527
+ o_v = i_v * BV + tl.arange(0, BV)
528
+ m_v = o_v < V
529
+
530
+ if i_t * BT + i_i * BC > T:
531
+ return
532
+
533
+ if HEAD_FIRST:
534
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
535
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV)
536
+ else:
537
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
538
+ p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
539
+ # [BV,]
540
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
541
+ # [BC, BV]
542
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
543
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
544
+ for i_j in range(i_i + 1, NC):
545
+ if HEAD_FIRST:
546
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
547
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
548
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
549
+ else:
550
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
551
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
552
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
553
+ # [BC, BV]
554
+ b_g = tl.load(p_g, boundary_check=(0, 1))
555
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
556
+ # [BC, BC]
557
+ b_A = tl.load(p_A, boundary_check=(0, 1))
558
+ # [BC, BV]
559
+ b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
560
+ b_dv *= exp(b_gn[None, :] - b_gv)
561
+
562
+ o_i = tl.arange(0, BC)
563
+ o_c = i_i * BC + tl.arange(0, BC)
564
+
565
+ if HEAD_FIRST:
566
+ p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
567
+ p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC)
568
+ p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
569
+ else:
570
+ p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
571
+ p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
572
+ p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
573
+
574
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
575
+ # [BC,]
576
+ b_A = tl.load(p_A)
577
+ # [BV,]
578
+ b_g = tl.load(p_g, mask=m_v, other=0)
579
+ b_do = tl.load(p_do, mask=m_v, other=0)
580
+ # [BC, BV]
581
+ m_i = o_i[:, None] <= j
582
+ b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)
583
+
584
+ p_g += (1 if HEAD_FIRST else H) * V
585
+ p_A += (1 if HEAD_FIRST else HQ) * BT
586
+ p_do += (1 if HEAD_FIRST else HQ) * V
587
+ if HEAD_FIRST:
588
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
589
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
590
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
591
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
592
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
593
+ else:
594
+ p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
595
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
596
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
597
+ p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
598
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
599
+
600
+ b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
601
+ b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
602
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
603
+ b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
604
+ b_dg = b_o * b_do - b_v * b_dv
605
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
606
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
607
+
608
+
609
+ def chunk_gsa_fwd_v(
610
+ q: torch.Tensor,
611
+ k: torch.Tensor,
612
+ v: torch.Tensor,
613
+ g: torch.Tensor,
614
+ scale: float = 1.,
615
+ initial_state: Optional[torch.Tensor] = None,
616
+ output_final_state: bool = False,
617
+ offsets: Optional[torch.LongTensor] = None,
618
+ indices: Optional[torch.LongTensor] = None,
619
+ head_first: bool = True,
620
+ chunk_size: int = 64
621
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
622
+ _, A, h, ht, o = chunk_gla_fwd(
623
+ q=q,
624
+ k=k,
625
+ v=v,
626
+ g=None,
627
+ g_cumsum=g,
628
+ scale=scale,
629
+ initial_state=initial_state,
630
+ output_final_state=output_final_state,
631
+ offsets=offsets,
632
+ indices=indices,
633
+ head_first=head_first,
634
+ chunk_size=chunk_size
635
+ )
636
+ return A, h, ht, o
637
+
638
+
639
+ def chunk_gsa_fwd_k(
640
+ q: torch.Tensor,
641
+ k: torch.Tensor,
642
+ v: torch.Tensor,
643
+ g: torch.Tensor,
644
+ h0: Optional[torch.Tensor] = None,
645
+ output_final_state: bool = False,
646
+ scale: float = 1.,
647
+ offsets: Optional[torch.LongTensor] = None,
648
+ indices: Optional[torch.LongTensor] = None,
649
+ head_first: bool = True,
650
+ chunk_size: int = 64
651
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
652
+ if head_first:
653
+ B, H, T, K, V = *k.shape, v.shape[-1]
654
+ else:
655
+ B, T, H, K, V = *k.shape, v.shape[-1]
656
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
657
+ BC = min(16, BT)
658
+ BV = min(64, triton.next_power_of_2(V))
659
+ HQ = q.shape[1] if head_first else q.shape[2]
660
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
661
+ NC = triton.cdiv(BT, BC)
662
+ NG = HQ // H
663
+
664
+ h, ht = chunk_fwd_h(
665
+ k=k,
666
+ v=v,
667
+ g=None,
668
+ gk=None,
669
+ gv=g,
670
+ h0=h0,
671
+ output_final_state=output_final_state,
672
+ offsets=offsets,
673
+ head_first=head_first,
674
+ chunk_size=BT,
675
+ states_in_fp32=False
676
+ )
677
+ o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V)
678
+ A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT)
679
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
680
+ chunk_gsa_fwd_k_kernel_inter[grid](
681
+ q,
682
+ k,
683
+ h,
684
+ g,
685
+ o,
686
+ A,
687
+ offsets=offsets,
688
+ indices=indices,
689
+ scale=scale,
690
+ T=T,
691
+ HQ=HQ,
692
+ H=H,
693
+ K=K,
694
+ V=V,
695
+ BT=BT,
696
+ NG=NG,
697
+ HEAD_FIRST=head_first
698
+ )
699
+
700
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
701
+ chunk_gsa_fwd_k_kernel_intra[grid](
702
+ v,
703
+ g,
704
+ o,
705
+ A,
706
+ offsets=offsets,
707
+ indices=indices,
708
+ T=T,
709
+ HQ=HQ,
710
+ H=H,
711
+ V=V,
712
+ BT=BT,
713
+ BC=BC,
714
+ BV=BV,
715
+ NC=NC,
716
+ NG=NG,
717
+ HEAD_FIRST=head_first,
718
+ num_warps=4,
719
+ num_stages=2
720
+ )
721
+ return A, h, ht, o
722
+
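As an aside, the tile-size bookkeeping in `chunk_gsa_fwd_k` above is easy to trace by hand. The following sketch is purely illustrative (shape values are my own choice) and simply mirrors the formulas used in the function:

```python
# Illustrative only: mirrors the BT/BC/BV/NT/NC/NG formulas used in chunk_gsa_fwd_k.
import triton

T, V, H, HQ, chunk_size = 2048, 64, 4, 8, 64
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))  # 64: time tile per chunk
BC = min(16, BT)                                          # 16: intra-chunk sub-tile
BV = min(64, triton.next_power_of_2(V))                   # 64: value-dim tile
NT = triton.cdiv(T, BT)                                   # 32: chunks along time
NC = triton.cdiv(BT, BC)                                  # 4: sub-tiles per chunk
NG = HQ // H                                              # 2: query heads per k/v head (GQA)
print(BT, BC, BV, NT, NC, NG)
```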
723
+
724
+ def chunk_gsa_bwd_v(
725
+ q: torch.Tensor,
726
+ k: torch.Tensor,
727
+ v: torch.Tensor,
728
+ g: torch.Tensor,
729
+ h0: torch.Tensor,
730
+ h: torch.Tensor,
731
+ A: torch.Tensor,
732
+ do: torch.Tensor,
733
+ dht: torch.Tensor,
734
+ dg: torch.Tensor,
735
+ scale: float = 1.,
736
+ offsets: Optional[torch.LongTensor] = None,
737
+ indices: Optional[torch.LongTensor] = None,
738
+ head_first: bool = True,
739
+ chunk_size: int = 64
740
+ ):
741
+ dq, dk, dv, dg, dh0 = chunk_gla_bwd(
742
+ q=q,
743
+ k=k,
744
+ v=v,
745
+ g=None,
746
+ g_cumsum=g,
747
+ scale=scale,
748
+ initial_state=h0,
749
+ h=h,
750
+ A=A,
751
+ do=do,
752
+ dht=dht,
753
+ offsets=offsets,
754
+ indices=indices,
755
+ head_first=head_first,
756
+ chunk_size=chunk_size
757
+ )
758
+ return dq, dk, dv, dg, dh0
759
+
760
+
761
+ def chunk_gsa_bwd_k(
762
+ q: torch.Tensor,
763
+ k: torch.Tensor,
764
+ v: torch.Tensor,
765
+ g: torch.Tensor,
766
+ h: torch.Tensor,
767
+ h0: torch.Tensor,
768
+ o: torch.Tensor,
769
+ do: torch.Tensor,
770
+ dht: torch.Tensor,
771
+ dg: torch.Tensor,
772
+ scale: float = 1.,
773
+ offsets: Optional[torch.LongTensor] = None,
774
+ indices: Optional[torch.LongTensor] = None,
775
+ head_first: bool = True,
776
+ chunk_size: int = 64
777
+ ):
778
+ if head_first:
779
+ B, H, T, K, V = *k.shape, v.shape[-1]
780
+ else:
781
+ B, T, H, K, V = *k.shape, v.shape[-1]
782
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
783
+ BC = min(16, BT)
784
+ BK = min(64, triton.next_power_of_2(K))
785
+ BV = min(64, triton.next_power_of_2(V))
786
+ HQ = q.shape[1] if head_first else q.shape[2]
787
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
788
+ NC = triton.cdiv(BT, BC)
789
+ NK = triton.cdiv(K, BK)
790
+ NV = triton.cdiv(V, BV)
791
+ NG = HQ // H
792
+
793
+ if h is None:
794
+ h, _ = chunk_fwd_h(
795
+ k=k,
796
+ v=v,
797
+ g=None,
798
+ gk=None,
799
+ gv=g,
800
+ h0=h0,
801
+ output_final_state=False,
802
+ offsets=offsets,
803
+ head_first=head_first,
804
+ chunk_size=BT,
805
+ states_in_fp32=False
806
+ )
807
+ dh, dh0 = chunk_bwd_dh(
808
+ q=q,
809
+ k=k,
810
+ v=v,
811
+ g=None,
812
+ gk=None,
813
+ gv=g,
814
+ do=do,
815
+ h0=h0,
816
+ dht=dht,
817
+ scale=scale,
818
+ offsets=offsets,
819
+ head_first=head_first,
820
+ chunk_size=BT,
821
+ states_in_fp32=True
822
+ )
823
+ dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT)
824
+ grid = (NV, NT * NC * NC, B * HQ)
825
+ chunk_gsa_bwd_k_kernel_dA[grid](
826
+ v,
827
+ g,
828
+ do,
829
+ dA,
830
+ offsets=offsets,
831
+ indices=indices,
832
+ scale=scale,
833
+ T=T,
834
+ B=B,
835
+ HQ=HQ,
836
+ H=H,
837
+ V=V,
838
+ BT=BT,
839
+ BC=BC,
840
+ BV=BV,
841
+ NC=NC,
842
+ NG=NG,
843
+ HEAD_FIRST=head_first
844
+ )
845
+ dA = dA.sum(0, dtype=dA.dtype)
846
+
847
+ A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT)
848
+ dq = torch.empty_like(q)
849
+ dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K)
850
+ dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V)
851
+ dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float)
852
+ grid = (NK, NT, B * HQ)
853
+ chunk_gsa_bwd_k_kernel_dqkvg[grid](
854
+ q,
855
+ k,
856
+ v,
857
+ h,
858
+ g,
859
+ A,
860
+ do,
861
+ dh,
862
+ dq,
863
+ dk,
864
+ dv,
865
+ dg,
866
+ dgv,
867
+ dA,
868
+ offsets=offsets,
869
+ indices=indices,
870
+ scale=scale,
871
+ T=T,
872
+ B=B,
873
+ HQ=HQ,
874
+ H=H,
875
+ K=K,
876
+ V=V,
877
+ BT=BT,
878
+ BK=BK,
879
+ BV=BV,
880
+ NG=NG,
881
+ HEAD_FIRST=head_first
882
+ )
883
+ A = A.sum(0, dtype=A.dtype)
884
+ dv = dv.sum(0, dtype=dv.dtype)
885
+ dgv = dgv.sum(0, dtype=dgv.dtype)
886
+
887
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
888
+ chunk_gsa_bwd_k_kernel_intra_dvg[grid](
889
+ v,
890
+ g,
891
+ o,
892
+ A,
893
+ do,
894
+ dv,
895
+ dg,
896
+ offsets=offsets,
897
+ indices=indices,
898
+ T=T,
899
+ HQ=HQ,
900
+ H=H,
901
+ V=V,
902
+ BT=BT,
903
+ BC=BC,
904
+ BV=BV,
905
+ NC=NC,
906
+ NG=NG,
907
+ HEAD_FIRST=head_first,
908
+ num_warps=4,
909
+ num_stages=2
910
+ )
911
+ dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first))
912
+
913
+ return dq, dk, dv, dg, dh0
914
+
915
+
916
+ def chunk_gsa_fwd(
917
+ q: torch.Tensor,
918
+ k: torch.Tensor,
919
+ v: torch.Tensor,
920
+ s: torch.Tensor,
921
+ g: torch.Tensor,
922
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
923
+ output_final_state: bool = False,
924
+ scale: float = 1.,
925
+ offsets: Optional[torch.LongTensor] = None,
926
+ indices: Optional[torch.LongTensor] = None,
927
+ head_first: bool = True,
928
+ chunk_size: int = 64
929
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
930
+ hk0, hv0 = None, None
931
+ if initial_state is not None:
932
+ hk0, hv0 = initial_state
933
+ Ak, hk, hkt, ok = chunk_gsa_fwd_k(
934
+ q=q,
935
+ k=k,
936
+ v=s,
937
+ g=g,
938
+ h0=hk0,
939
+ output_final_state=output_final_state,
940
+ scale=scale,
941
+ offsets=offsets,
942
+ indices=indices,
943
+ head_first=head_first,
944
+ chunk_size=chunk_size
945
+ )
946
+
947
+ # p is kept in fp32 for safe softmax backward
948
+ p = softmax_fwd(ok, dtype=torch.float)
949
+
950
+ qv = p.to(q.dtype)
951
+ Av, hv, hvt, ov = chunk_gsa_fwd_v(
952
+ q=qv,
953
+ k=s,
954
+ v=v,
955
+ g=g,
956
+ scale=1.,
957
+ initial_state=hv0,
958
+ output_final_state=output_final_state,
959
+ offsets=offsets,
960
+ indices=indices,
961
+ head_first=head_first,
962
+ chunk_size=chunk_size
963
+ )
964
+ return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
965
+
966
+
967
+ def chunk_gsa_bwd(
968
+ q: torch.Tensor,
969
+ k: torch.Tensor,
970
+ v: torch.Tensor,
971
+ s: torch.Tensor,
972
+ g: torch.Tensor,
973
+ ok: torch.Tensor,
974
+ p: torch.Tensor,
975
+ A: Tuple[torch.Tensor, torch.Tensor],
976
+ h: Tuple[torch.Tensor, torch.Tensor],
977
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
978
+ scale: float,
979
+ do: torch.Tensor,
980
+ dht: Tuple[torch.Tensor, torch.Tensor],
981
+ offsets: Optional[torch.LongTensor] = None,
982
+ indices: Optional[torch.LongTensor] = None,
983
+ head_first: bool = True,
984
+ chunk_size: int = 64
985
+ ):
986
+ hk0, hv0 = None, None
987
+ if initial_state is not None:
988
+ hk0, hv0 = initial_state
989
+
990
+ _, Av = A
991
+ hk, hv = h
992
+ dhkt, dhvt = dht
993
+
994
+ qv = p.to(q.dtype)
995
+ dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
996
+ q=qv,
997
+ k=s,
998
+ v=v,
999
+ g=g,
1000
+ h0=hv0,
1001
+ h=hv,
1002
+ A=Av,
1003
+ do=do,
1004
+ dht=dhvt,
1005
+ dg=None,
1006
+ scale=1.,
1007
+ offsets=offsets,
1008
+ indices=indices,
1009
+ head_first=head_first,
1010
+ chunk_size=chunk_size
1011
+ )
1012
+
1013
+ # softmax gradient, equivalent to:
1014
+ # dok = qv * (dqv - (qv * dqv).sum(-1, True))
1015
+ dok = softmax_bwd(p, dqv, dtype=ok.dtype)
1016
+
1017
+ dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
1018
+ q=q,
1019
+ k=k,
1020
+ v=s,
1021
+ g=g,
1022
+ h0=hk0,
1023
+ h=hk,
1024
+ o=ok,
1025
+ do=dok,
1026
+ dht=dhkt,
1027
+ dg=dg,
1028
+ scale=scale,
1029
+ offsets=offsets,
1030
+ indices=indices,
1031
+ head_first=head_first,
1032
+ chunk_size=chunk_size
1033
+ )
1034
+
1035
+ ds = dsv.add_(dsk)
1036
+ if q.shape[1] != k.shape[1]:
1037
+ dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
1038
+ dg = dg.to(s.dtype)
1039
+ return dq, dk, dv, ds, dg, dhk0, dhv0
1040
+
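The softmax-gradient identity quoted in the comment inside `chunk_gsa_bwd` above can be checked directly against autograd; this snippet is purely illustrative:

```python
# Checks dok = p * (dqv - (p * dqv).sum(-1, keepdim=True)) against PyTorch autograd.
import torch

x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
p = torch.softmax(x, dim=-1)
dqv = torch.randn(8, 16, dtype=torch.double)
p.backward(dqv)
dok = p.detach() * (dqv - (p.detach() * dqv).sum(-1, keepdim=True))
print(torch.allclose(x.grad, dok))  # True
```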
1041
+
1042
+ class ChunkGSAFunction(torch.autograd.Function):
1043
+
1044
+ @staticmethod
1045
+ @input_guard
1046
+ def forward(
1047
+ ctx,
1048
+ q: torch.Tensor,
1049
+ k: torch.Tensor,
1050
+ v: torch.Tensor,
1051
+ s: torch.Tensor,
1052
+ g: torch.Tensor,
1053
+ scale: float,
1054
+ hk0: Optional[torch.Tensor],
1055
+ hv0: Optional[torch.Tensor],
1056
+ output_final_state: bool,
1057
+ checkpoint_level: int,
1058
+ offsets: Optional[torch.LongTensor],
1059
+ head_first: bool = True
1060
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1061
+ T = q.shape[2] if head_first else q.shape[1]
1062
+ chunk_size = min(64, max(16, triton.next_power_of_2(T)))
1063
+
1064
+ # 2-d indices denoting the offsets of chunks in each sequence
1065
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
1066
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
1067
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
1068
+ indices = None
1069
+ if offsets is not None:
1070
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
1071
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
1072
+ g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1073
+ Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
1074
+ q=q,
1075
+ k=k,
1076
+ v=v,
1077
+ s=s,
1078
+ g=g,
1079
+ initial_state=(hk0, hv0),
1080
+ output_final_state=output_final_state,
1081
+ scale=scale,
1082
+ offsets=offsets,
1083
+ indices=indices,
1084
+ head_first=head_first,
1085
+ chunk_size=chunk_size
1086
+ )
1087
+
1088
+ if checkpoint_level >= 1:
1089
+ del g
1090
+ g = g_org
1091
+ if checkpoint_level > 1:
1092
+ del hk
1093
+ del hv
1094
+ hk, hv = None, None
1095
+ else:
1096
+ hk0, hv0 = None, None
1097
+
1098
+ ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
1099
+ ctx.checkpoint_level = checkpoint_level
1100
+ ctx.scale = scale
1101
+ ctx.offsets = offsets
1102
+ ctx.indices = indices
1103
+ ctx.head_first = head_first
1104
+ ctx.chunk_size = chunk_size
1105
+ return ov, hkt, hvt
1106
+
1107
+ @staticmethod
1108
+ @input_guard
1109
+ def backward(ctx, dov, dhkt=None, dhvt=None):
1110
+ q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
1111
+ scale = ctx.scale
1112
+ offsets = ctx.offsets
1113
+ indices = ctx.indices
1114
+ head_first = ctx.head_first
1115
+ chunk_size = ctx.chunk_size
1116
+
1117
+ if ctx.checkpoint_level >= 1:
1118
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1119
+ dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
1120
+ q=q,
1121
+ k=k,
1122
+ v=v,
1123
+ s=s,
1124
+ g=g,
1125
+ ok=ok,
1126
+ p=p,
1127
+ A=(None, Av),
1128
+ h=(hk, hv),
1129
+ initial_state=(hk0, hv0),
1130
+ scale=scale,
1131
+ do=dov,
1132
+ dht=(dhkt, dhvt),
1133
+ offsets=offsets,
1134
+ indices=indices,
1135
+ head_first=head_first,
1136
+ chunk_size=chunk_size
1137
+ )
1138
+ return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
1139
+
1140
+
1141
+ @torch.compiler.disable
1142
+ def chunk_gsa(
1143
+ q: torch.Tensor,
1144
+ k: torch.Tensor,
1145
+ v: torch.Tensor,
1146
+ s: torch.Tensor,
1147
+ g: Optional[torch.Tensor] = None,
1148
+ scale: Optional[float] = None,
1149
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1150
+ output_final_state: Optional[bool] = False,
1151
+ checkpoint_level: Optional[int] = 2,
1152
+ cu_seqlens: Optional[torch.LongTensor] = None,
1153
+ head_first: Optional[bool] = True
1154
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1155
+ r"""
1156
+ Args:
1157
+ q (torch.Tensor):
1158
+ queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
1159
+ k (torch.Tensor):
1160
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
1161
+ GQA is performed if `H` is not equal to `HQ`.
1162
+ v (torch.Tensor):
1163
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1164
+ s (torch.Tensor):
1165
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
1166
+ g (torch.Tensor):
1167
+ Forget gates of shape `[B, H, T, M]` applied to keys.
1168
+ If not provided, this function is equivalent to vanilla ABC.
1169
+ scale (Optional[float]):
1170
+ Scale factor for attention scores.
1171
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1172
+ initial_state (Optional[Tuple[torch.Tensor]]):
1173
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
1174
+ For equal-length input sequences, `N` equals the batch size `B`.
1175
+ Default: `None`.
1176
+ output_final_state (Optional[bool]):
1177
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
1178
+ Default: `False`.
1179
+ checkpoint_level (Optional[int]):
1180
+ Checkpointing level; higher values save more memory but require more recomputation during the backward pass.
1181
+ Default: `2`:
1182
+ - Level `0`: no memory saved, no recomputation.
1183
+ - Level `1`: recompute the fp32 cumulative values during backward.
1184
+ - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
1185
+ cu_seqlens (torch.LongTensor):
1186
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1187
+ consistent with the FlashAttention API.
1188
+ head_first (Optional[bool]):
1189
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
1190
+ Default: `True`.
1191
+
1192
+ Returns:
1193
+ o (torch.Tensor):
1194
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1195
+ final_state (Tuple[torch.Tensor]):
1196
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
1197
+ `None` otherwise.
1198
+
1199
+ Examples::
1200
+ >>> import torch
1201
+ >>> import torch.nn.functional as F
1202
+ >>> from einops import rearrange
1203
+ >>> from fla.ops.gsa import fused_recurrent_gsa
1204
+ # inputs with equal lengths
1205
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
1206
+ >>> q = torch.randn(B, T, H, K, device='cuda')
1207
+ >>> k = torch.randn(B, T, H, K, device='cuda')
1208
+ >>> v = torch.randn(B, T, H, V, device='cuda')
1209
+ >>> s = torch.randn(B, T, H, M, device='cuda')
1210
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
1211
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
1212
+ >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g,
1213
+ initial_state=h0,
1214
+ output_final_state=True,
1215
+ head_first=False)
1216
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
1217
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
1218
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
1219
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
1220
+ >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g,
1221
+ initial_state=h0,
1222
+ output_final_state=True,
1223
+ cu_seqlens=cu_seqlens,
1224
+ head_first=False)
1225
+ >>> assert o.allclose(o_var.view(o.shape))
1226
+ >>> assert hk.allclose(hk_var)
1227
+ >>> assert hv.allclose(hv_var)
1228
+ """
1229
+ if cu_seqlens is not None:
1230
+ if q.shape[0] != 1:
1231
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
1232
+ f" Please flatten variable-length inputs before processing.")
1233
+ if head_first:
1234
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
1235
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
1236
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
1237
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
1238
+ assert checkpoint_level in [0, 1, 2]
1239
+ if g is None:
1240
+ # TODO: these 3 steps take a huge amount of time and ought to be optimized
1241
+ z = s.float().logcumsumexp(2)
1242
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
1243
+ s = torch.exp(s - z).to(k.dtype)
1244
+ if scale is None:
1245
+ scale = q.shape[-1] ** -0.5
1246
+
1247
+ hk0, hv0 = None, None
1248
+ if initial_state is not None:
1249
+ hk0, hv0 = initial_state
1250
+ o, *final_state = ChunkGSAFunction.apply(
1251
+ q,
1252
+ k,
1253
+ v,
1254
+ s,
1255
+ g,
1256
+ scale,
1257
+ hk0,
1258
+ hv0,
1259
+ output_final_state,
1260
+ checkpoint_level,
1261
+ cu_seqlens,
1262
+ head_first
1263
+ )
1264
+ return o, final_state
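A minimal way to exercise `chunk_gsa` end to end is to compare it against the naive recurrence added in `fla/ops/gsa/naive.py` below. This is a sketch only: the import paths are assumed from this diff, a CUDA device is required, and small numerical differences between the Triton kernels and the fp32 reference are expected.

```python
import torch
import torch.nn.functional as F

from fla.ops.gsa.chunk import chunk_gsa            # path assumed from this diff
from fla.ops.gsa.naive import naive_recurrent_gsa

B, H, T, K, V, M = 2, 4, 128, 64, 64, 32
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
s = torch.randn(B, H, T, M, device='cuda')
g = F.logsigmoid(torch.randn(B, H, T, M, device='cuda'))

# head_first=True matches the [B, H, T, ...] layout the naive reference expects.
o_chunk, _ = chunk_gsa(q, k, v, s, g, head_first=True)
o_naive, _ = naive_recurrent_gsa(q, k, v, s, g)
print((o_chunk - o_naive).abs().max())  # should be small
```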
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_gsa(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[int] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
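In equations, per head and per time step, the two loops in `naive_recurrent_gsa` realize the following two-stage recurrence (notation mine; $h^{k}_{t} \in \mathbb{R}^{K\times M}$, $h^{v}_{t} \in \mathbb{R}^{M\times V}$, $s_t \in \mathbb{R}^{M}$ is the slot vector, and $q_t$ is already scaled):

$$
\begin{aligned}
h^{k}_{t} &= h^{k}_{t-1}\,\operatorname{diag}\!\big(e^{g_t}\big) + k_t\, s_t^{\top}, &
o^{k}_{t} &= \big(h^{k}_{t}\big)^{\!\top} q_t,\\
p_t &= \operatorname{softmax}\big(o^{k}_{t}\big), &
h^{v}_{t} &= \operatorname{diag}\!\big(e^{g_t}\big)\, h^{v}_{t-1} + s_t\, v_t^{\top},\\
o^{v}_{t} &= \big(h^{v}_{t}\big)^{\!\top} p_t. & &
\end{aligned}
$$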
fla/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ offsets,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ USE_OFFSETS: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if USE_OFFSETS:
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
+ @triton.heuristics({
76
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
77
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
78
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
79
+ })
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config({'BD': BD}, num_warps=num_warps)
83
+ for BD in [32, 64, 128]
84
+ for num_warps in [1, 2, 4, 8]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fused_recurrent_hgrn_bwd_kernel(
90
+ g,
91
+ o,
92
+ h0,
93
+ dx,
94
+ dg,
95
+ do,
96
+ dht,
97
+ dh0,
98
+ offsets,
99
+ T,
100
+ D: tl.constexpr,
101
+ BD: tl.constexpr,
102
+ USE_INITIAL_STATE: tl.constexpr,
103
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
104
+ USE_OFFSETS: tl.constexpr
105
+ ):
106
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
107
+ if USE_OFFSETS:
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_n * T, i_n * T + T
112
+
113
+ o_d = i_d * BD + tl.arange(0, BD)
114
+ mask = o_d < D
115
+
116
+ p_g = g + (bos + T - 1) * D + o_d
117
+ p_o = o + (bos + T - 2) * D + o_d
118
+ p_dx = dx + (bos + T - 1) * D + o_d
119
+ p_dg = dg + (bos + T - 1) * D + o_d
120
+ p_do = do + (bos + T - 1) * D + o_d
121
+
122
+ b_dh = tl.zeros([BD], dtype=tl.float32)
123
+ if USE_FINAL_STATE_GRADIENT:
124
+ p_dht = dht + i_n * D + o_d
125
+ b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)
126
+
127
+ for i in range(T - 1, -1, -1):
128
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
129
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
130
+ if i > 0:
131
+ b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
132
+ elif USE_INITIAL_STATE:
133
+ b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
134
+ else:
135
+ b_o = tl.zeros([BD], dtype=tl.float32)
136
+
137
+ b_dh = b_dh + b_do
138
+ b_dx = b_dh
139
+ b_dh = b_dh * exp(b_g)
140
+ b_dg = b_dh * b_o
141
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
142
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
143
+
144
+ p_g -= D
145
+ p_o -= D
146
+ p_dx -= D
147
+ p_dg -= D
148
+ p_do -= D
149
+
150
+ if USE_INITIAL_STATE:
151
+ p_dh0 = dh0 + i_n * D + o_d
152
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
155
+ def fused_recurrent_hgrn_fwd(
156
+ x: torch.Tensor,
157
+ g: torch.Tensor,
158
+ initial_state: torch.Tensor = None,
159
+ output_final_state: bool = False,
160
+ offsets: Optional[torch.LongTensor] = None,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ B, T, D = x.shape
163
+ N = B if offsets is None else len(offsets) - 1
164
+
165
+ o = torch.empty_like(x)
166
+ final_state = x.new_empty(N, D) if output_final_state else None
167
+
168
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
169
+ fused_recurrent_hgrn_fwd_kernel[grid](
170
+ x=x,
171
+ g=g,
172
+ o=o,
173
+ h0=initial_state,
174
+ ht=final_state,
175
+ offsets=offsets,
176
+ T=T,
177
+ D=D
178
+ )
179
+ return o, final_state
180
+
181
+
182
+ def fused_recurrent_hgrn_bwd(
183
+ g: torch.Tensor,
184
+ o: torch.Tensor,
185
+ do: torch.Tensor,
186
+ dht: torch.Tensor = None,
187
+ initial_state: torch.Tensor = None,
188
+ offsets: Optional[torch.LongTensor] = None
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ B, T, D = do.shape
191
+ N = B if offsets is None else len(offsets) - 1
192
+
193
+ dx = torch.empty_like(o, dtype=torch.float)
194
+ dg = torch.empty_like(g, dtype=torch.float)
195
+ dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
196
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
197
+ fused_recurrent_hgrn_bwd_kernel[grid](
198
+ g=g,
199
+ o=o,
200
+ h0=initial_state,
201
+ dx=dx,
202
+ dg=dg,
203
+ do=do,
204
+ dht=dht,
205
+ dh0=dh0,
206
+ offsets=offsets,
207
+ T=T,
208
+ D=D
209
+ )
210
+ return dx, dg, dh0
211
+
212
+
213
+ class FusedRecurrentHGRNFunction(torch.autograd.Function):
214
+
215
+ @staticmethod
216
+ @input_guard
217
+ def forward(
218
+ ctx,
219
+ x: torch.Tensor,
220
+ g: torch.Tensor,
221
+ initial_state: torch.Tensor = None,
222
+ output_final_state: bool = False,
223
+ offsets: Optional[torch.LongTensor] = None
224
+ ):
225
+ o, ht = fused_recurrent_hgrn_fwd(
226
+ x=x,
227
+ g=g,
228
+ initial_state=initial_state,
229
+ output_final_state=output_final_state,
230
+ offsets=offsets
231
+ )
232
+ ctx.save_for_backward(g, o, initial_state)
233
+ ctx.offsets = offsets
234
+ return o, ht
235
+
236
+ @staticmethod
237
+ @input_guard
238
+ def backward(ctx, do, dht=None):
239
+ g, o, initial_state = ctx.saved_tensors
240
+ offsets = ctx.offsets
241
+
242
+ dx, dg, dh0 = fused_recurrent_hgrn_bwd(
243
+ g=g,
244
+ o=o,
245
+ do=do,
246
+ dht=dht,
247
+ initial_state=initial_state,
248
+ offsets=offsets
249
+ )
250
+ return dx, dg, dh0, None, None
251
+
252
+
253
+ @torch.compiler.disable
254
+ def fused_recurrent_hgrn(
255
+ x: torch.Tensor,
256
+ g: torch.Tensor,
257
+ initial_state: torch.Tensor = None,
258
+ output_final_state: bool = False,
259
+ cu_seqlens: Optional[torch.LongTensor] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ r"""
262
+ Args:
263
+ x (torch.Tensor):
264
+ inputs of shape `[B, T, D]`.
265
+ g (torch.Tensor):
266
+ Forget gates of shape `[B, T, D]`.
267
+ initial_state (Optional[torch.Tensor]):
268
+ Initial state of shape `[N, D]` for `N` input sequences.
269
+ For equal-length input sequences, `N` equals the batch size `B`.
270
+ Default: `None`.
271
+ output_final_state (Optional[bool]):
272
+ Whether to output the final state of shape `[N, D]`. Default: `False`.
273
+ cu_seqlens (torch.LongTensor):
274
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
275
+ consistent with the FlashAttention API.
276
+
277
+ Returns:
278
+ o (torch.Tensor):
279
+ Outputs of shape `[B, T, D]`.
280
+ final_state (torch.Tensor):
281
+ Final state of shape `[N, D]` if `output_final_state=True` else `None`.
282
+
283
+ Examples::
284
+ >>> import torch
285
+ >>> import torch.nn.functional as F
286
+ >>> from einops import rearrange
287
+ >>> from fla.ops.hgrn import fused_recurrent_hgrn
288
+ # inputs with equal lengths
289
+ >>> B, T, D = 4, 2048, 512
290
+ >>> x = torch.randn(B, T, D, device='cuda')
291
+ >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
292
+ >>> h0 = torch.randn(B, D, device='cuda')
293
+ >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
294
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
295
+ >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
296
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
297
+ >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
298
+ >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
299
+ >>> assert o.allclose(o_var.view(o.shape))
300
+ >>> assert ht.allclose(ht_var)
301
+ """
302
+ return FusedRecurrentHGRNFunction.apply(
303
+ x,
304
+ g,
305
+ initial_state,
306
+ output_final_state,
307
+ cu_seqlens
308
+ )
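For reference, with $h_t = e^{g_t} \odot h_{t-1} + x_t$ and $o_t = h_t$, the backward kernel above accumulates, element-wise (notation mine):

$$
\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t} + e^{g_{t+1}} \odot \frac{\partial L}{\partial h_{t+1}},\qquad
\frac{\partial L}{\partial x_t} = \frac{\partial L}{\partial h_t},\qquad
\frac{\partial L}{\partial g_t} = e^{g_t} \odot \frac{\partial L}{\partial h_t} \odot h_{t-1},
$$

where at the boundaries $e^{g_{T+1}} \odot \partial L/\partial h_{T+1}$ is replaced by `dht` and $h_0$ by the initial state (or zero), and the value written to `dh0` is $e^{g_1} \odot \partial L/\partial h_1$.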
fla/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ def naive_recurrent_hgrn(
9
+ x: torch.Tensor,
10
+ g: torch.Tensor,
11
+ initial_state: Optional[torch.Tensor] = None,
12
+ output_final_state: Optional[bool] = False
13
+ ) -> torch.Tensor:
14
+ dtype = x.dtype
15
+ x, g = map(lambda i: i.float(), (x, g))
16
+ B, T, D = x.shape
17
+
18
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
19
+ o = torch.zeros_like(x)
20
+
21
+ final_state = None
22
+ if initial_state is not None:
23
+ h += initial_state
24
+
25
+ for i in range(T):
26
+ h = g[:, i].exp() * h + x[:, i]
27
+ o[:, i] = h
28
+
29
+ if output_final_state:
30
+ final_state = h
31
+ return o.to(dtype), final_state
32
+
33
+
34
+ def naive_chunk_hgrn(
35
+ x: torch.Tensor,
36
+ g: torch.Tensor,
37
+ initial_state: Optional[torch.Tensor] = None,
38
+ output_final_state: Optional[bool] = False,
39
+ chunk_size: int = 64
40
+ ) -> torch.Tensor:
41
+ dtype = x.dtype
42
+ x, g = map(lambda i: i.float(), (x, g))
43
+ B, T, D = x.shape
44
+
45
+ gc = g.view(B, -1, chunk_size, D).cumsum(-2).view_as(g)
46
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
47
+ o = torch.zeros_like(x)
48
+
49
+ final_state = None
50
+ if initial_state is not None:
51
+ h += initial_state
52
+
53
+ for i in range(0, T, chunk_size):
54
+ hp = h
55
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
56
+ for j in range(i, i + chunk_size):
57
+ h = g[:, j].exp() * h + x[:, j]
58
+ o[:, j] = hp * gc[:, j].exp() + h
59
+ h = o[:, j].clone()
60
+
61
+ if output_final_state:
62
+ final_state = h
63
+ return o.to(dtype), final_state
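A quick CPU sanity check of the two references in this file (a sketch; it assumes `T` is a multiple of `chunk_size`, which the chunked variant requires):

```python
import torch
import torch.nn.functional as F

from fla.ops.hgrn.naive import naive_chunk_hgrn, naive_recurrent_hgrn

B, T, D = 2, 256, 32
x = torch.randn(B, T, D)
g = F.logsigmoid(torch.randn(B, T, D))

o_rec, _ = naive_recurrent_hgrn(x, g)
o_chk, _ = naive_chunk_hgrn(x, g, chunk_size=64)
print((o_rec - o_chk).abs().max())  # expected to be ~0 up to float rounding
```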
fla/ops/rwkv4/fused_recurrent.py ADDED
@@ -0,0 +1,476 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Any, cast
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from torch import Tensor
10
+ from torch.autograd.function import Function, FunctionCtx, once_differentiable
11
+
12
+ from fla.ops.utils.op import exp
13
+
14
+
15
+ def get_block_size_c(chans: int) -> int:
16
+ if chans < 32:
17
+ return 32
18
+ if chans < 64:
19
+ return 64
20
+ return 128
21
+
22
+
23
+ @triton.jit
24
+ def fused_recurrent_rwkv4_forward_kernel(
25
+ # W
26
+ w_ptr,
27
+ w_s_c,
28
+ # U
29
+ u_ptr,
30
+ u_s_c,
31
+ # K
32
+ k_ptr,
33
+ k_s_b,
34
+ k_s_t,
35
+ k_s_c,
36
+ # V
37
+ v_ptr,
38
+ v_s_b,
39
+ v_s_t,
40
+ v_s_c,
41
+ # State
42
+ state_ptr,
43
+ state_s_b,
44
+ state_s_abe,
45
+ state_s_c,
46
+ # WKV
47
+ wkv_ptr,
48
+ wkv_s_b,
49
+ wkv_s_t,
50
+ wkv_s_c,
51
+ # Output state
52
+ state_out_ptr,
53
+ state_out_s_b,
54
+ state_out_s_abe,
55
+ state_out_s_t,
56
+ state_out_s_c,
57
+ # Params
58
+ chans,
59
+ tsz,
60
+ BLOCK_SIZE_C: tl.constexpr,
61
+ ):
62
+ # Parallelize over the batch dimension.
63
+ b_idx = tl.program_id(0)
64
+ c_idx = tl.program_id(1)
65
+
66
+ cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
67
+ cmask = cs < chans
68
+
69
+ # Pointers to the batch (and possibly channel) for the input tensors.
70
+ k_ptr = k_ptr + b_idx * k_s_b
71
+ v_ptr = v_ptr + b_idx * v_s_b
72
+ alpha_ptr = state_ptr + b_idx * state_s_b
73
+ beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
74
+ eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
75
+
76
+ # Pointers to the batch (and possibly channel) for the output tensors.
77
+ wkv_ptr = wkv_ptr + b_idx * wkv_s_b
78
+ alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
79
+ beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
80
+ eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe
81
+
82
+ # Loads parameters.
83
+ alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
84
+ beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
85
+ eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
86
+ w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
87
+ u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)
88
+
89
+ for t in range(tsz):
90
+ kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
91
+ vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)
92
+
93
+ ukt = u + kt
94
+ tau = tl.maximum(ukt, eps)
95
+ e1a = exp(eps - tau)
96
+ e2a = exp(ukt - tau)
97
+ wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
98
+ tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)
99
+
100
+ w_eps = w + eps
101
+ eps = tl.maximum(w_eps, kt)
102
+ e1b = exp(w_eps - eps)
103
+ e2b = exp(kt - eps)
104
+ alpha = e1b * alpha + e2b * vt
105
+ beta = e1b * beta + e2b
106
+ tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
107
+ tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
108
+ tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)
109
+
110
+
111
+ def fused_recurrent_rwkv4_forward(
112
+ w: Tensor,
113
+ u: Tensor,
114
+ k: Tensor,
115
+ v: Tensor,
116
+ state: Tensor,
117
+ ) -> tuple[Tensor, Tensor]:
118
+ (bsz, tsz, chans) = k.shape
119
+
120
+ # New tensors to output.
121
+ wkvs = k.new_empty(bsz, tsz, chans)
122
+ state_out = k.new_empty(bsz, 3, tsz, chans)
123
+
124
+ # Constants.
125
+ block_size_c = get_block_size_c(chans)
126
+
127
+ def grid(meta: dict[str, Any]) -> tuple[int, ...]:
128
+ return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
129
+
130
+ fused_recurrent_rwkv4_forward_kernel[grid](
131
+ # W
132
+ w,
133
+ w.stride(0),
134
+ # U
135
+ u,
136
+ u.stride(0),
137
+ # K
138
+ k,
139
+ k.stride(0),
140
+ k.stride(1),
141
+ k.stride(2),
142
+ # V
143
+ v,
144
+ v.stride(0),
145
+ v.stride(1),
146
+ v.stride(2),
147
+ # State
148
+ state,
149
+ state.stride(0),
150
+ state.stride(1),
151
+ state.stride(3),
152
+ # WKV
153
+ wkvs,
154
+ wkvs.stride(0),
155
+ wkvs.stride(1),
156
+ wkvs.stride(2),
157
+ # Output state
158
+ state_out,
159
+ state_out.stride(0),
160
+ state_out.stride(1),
161
+ state_out.stride(2),
162
+ state_out.stride(3),
163
+ # Params
164
+ chans,
165
+ tsz,
166
+ BLOCK_SIZE_C=block_size_c,
167
+ )
168
+
169
+ state_out = torch.cat((state, state_out), dim=2)
170
+
171
+ return wkvs, state_out
172
+
173
+
174
+ @triton.jit
175
+ def fused_recurrent_rwkv4_backward_kernel(
176
+ # W
177
+ w_ptr,
178
+ w_s_c,
179
+ # U
180
+ u_ptr,
181
+ u_s_c,
182
+ # K
183
+ k_ptr,
184
+ k_s_b,
185
+ k_s_t,
186
+ k_s_c,
187
+ # V
188
+ v_ptr,
189
+ v_s_b,
190
+ v_s_t,
191
+ v_s_c,
192
+ # State
193
+ state_ptr,
194
+ state_s_b,
195
+ state_s_abe,
196
+ state_s_t,
197
+ state_s_c,
198
+ # WKV grad
199
+ gwkv_ptr,
200
+ gwkv_s_b,
201
+ gwkv_s_t,
202
+ gwkv_s_c,
203
+ # Output state grad
204
+ gstate_out_ptr,
205
+ gstate_out_s_b,
206
+ gstate_out_s_abe,
207
+ gstate_out_s_c,
208
+ # W grad
209
+ gw_ptr,
210
+ gw_s_c,
211
+ # U grad
212
+ gu_ptr,
213
+ gu_s_c,
214
+ # K grad
215
+ gk_ptr,
216
+ gk_s_b,
217
+ gk_s_t,
218
+ gk_s_c,
219
+ # V grad
220
+ gv_ptr,
221
+ gv_s_b,
222
+ gv_s_t,
223
+ gv_s_c,
224
+ # State grad
225
+ gstate_ptr,
226
+ gstate_s_b,
227
+ gstate_s_abe,
228
+ gstate_s_c,
229
+ # Params
230
+ tsz,
231
+ chans,
232
+ BLOCK_SIZE_C: tl.constexpr,
233
+ ):
234
+ # Parallelize over the batch dimension.
235
+ b_idx = tl.program_id(0)
236
+ c_idx = tl.program_id(1)
237
+
238
+ cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
239
+ cmask = cs < chans
240
+
241
+ # Pointers to the batch (and possibly channel) for the input tensors.
242
+ k_ptr = k_ptr + b_idx * k_s_b
243
+ v_ptr = v_ptr + b_idx * v_s_b
244
+ alpha_ptr = state_ptr + b_idx * state_s_b
245
+ beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
246
+ eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
247
+
248
+ # Pointers to the batch (and possibly channel) for the output tensors.
249
+ gk_ptr = gk_ptr + b_idx * gk_s_b
250
+ gv_ptr = gv_ptr + b_idx * gv_s_b
251
+
252
+ # Pointers to gradients which were received by the function.
253
+ gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
254
+ galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
255
+ gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
256
+ geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe
257
+
258
+ # Loads parameters.
259
+ galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
260
+ gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
261
+ geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
262
+ w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
263
+ u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)
264
+
265
+ # Gradient accumulators.
266
+ gw = tl.zeros_like(w)
267
+ gu = tl.zeros_like(u)
268
+
269
+ alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
270
+ beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
271
+ eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
272
+
273
+ for t in range(tsz):
274
+ tc = tsz - t - 1
275
+
276
+ kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
277
+ vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)
278
+
279
+ alpha_curr = alpha_prev
280
+ beta_curr = beta_prev
281
+ eps_curr = eps_prev
282
+
283
+ alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
284
+ beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
285
+ eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
286
+
287
+ ukt = u + kt
288
+ tau = tl.maximum(ukt, eps_prev)
289
+ e1 = exp(eps_prev - tau)
290
+ e2 = exp(ukt - tau)
291
+
292
+ euke = exp(ukt + eps_prev - 2 * tau)
293
+
294
+ denom = e1 * beta_prev + e2
295
+ denom_sq = denom * denom
296
+
297
+ gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)
298
+
299
+ # Backpropagates wkv gradients.
300
+ guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
301
+ gu += guk
302
+ gk = guk
303
+ gv = gwkvt * e2 / denom
304
+
305
+ galpha_wkv = gwkvt * e1 / denom
306
+ gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
307
+ geps_wkv_denom = e1 * beta_prev + e2
308
+ geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)
309
+
310
+ e1 = exp(w + eps_prev - eps_curr)
311
+ e2 = exp(kt - eps_curr)
312
+
313
+ # Backpropagates alpha gradients.
314
+ galpha_we = galpha * e1 * alpha_prev
315
+ gw += galpha_we
316
+ gk += galpha * e2 * vt
317
+ gv += galpha * e2
318
+ geps += galpha * -alpha_curr
319
+
320
+ # Backpropagates beta gradients.
321
+ gbeta_we = gbeta * e1 * beta_prev
322
+ gw += gbeta_we
323
+ gk += gbeta * e2
324
+ geps += gbeta * -beta_curr
325
+
326
+ # Backpropagates epsilon gradients.
327
+ geps_mask = w + eps_prev > kt
328
+ geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
329
+ gw += geps_we
330
+ gk += tl.where(geps_mask, tl.zeros_like(geps), geps)
331
+
332
+ # Stores the gradients for k and v.
333
+ tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
334
+ tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)
335
+
336
+ # Computes new gradients for alpha and beta.
337
+ galpha = galpha * e1 + galpha_wkv
338
+ gbeta = gbeta * e1 + gbeta_wkv
339
+ geps = galpha_we + gbeta_we + geps_we + geps_wkv
340
+
341
+ # Stores final gradients for alpha and beta.
342
+ galpha_ptr = gstate_ptr + b_idx * gstate_s_b
343
+ gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
344
+ geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
345
+ tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
346
+ tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
347
+ tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)
348
+
349
+ # Stores final gradients for w and u.
350
+ gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
351
+ gw_temp += gw
352
+ tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
353
+ gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
354
+ gu_temp += gu
355
+ tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)
356
+
357
+
358
+ def fused_recurrent_rwkv4_backward(
359
+ w: Tensor,
360
+ u: Tensor,
361
+ k: Tensor,
362
+ v: Tensor,
363
+ state: Tensor,
364
+ grad_wkv: Tensor,
365
+ grad_state: Tensor,
366
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
367
+ bsz, tsz, chans = k.shape
368
+
369
+ gw = torch.zeros_like(w) # New tensors to output.
370
+ gu = torch.zeros_like(u)
371
+ gk = torch.empty_like(k)
372
+ gv = torch.empty_like(v)
373
+ gstate = k.new_empty(bsz, 3, 1, chans)
374
+
375
+ block_size_c = get_block_size_c(chans) # Constants.
376
+
377
+ def grid(meta: dict[str, Any]) -> tuple[int, ...]:
378
+ return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
379
+
380
+ fused_recurrent_rwkv4_backward_kernel[grid](
381
+ # W
382
+ w,
383
+ w.stride(0),
384
+ # U
385
+ u,
386
+ u.stride(0),
387
+ # K
388
+ k,
389
+ k.stride(0),
390
+ k.stride(1),
391
+ k.stride(2),
392
+ # V
393
+ v,
394
+ v.stride(0),
395
+ v.stride(1),
396
+ v.stride(2),
397
+ # State
398
+ state,
399
+ state.stride(0),
400
+ state.stride(1),
401
+ state.stride(2),
402
+ state.stride(3),
403
+ # WKV grad
404
+ grad_wkv,
405
+ grad_wkv.stride(0),
406
+ grad_wkv.stride(1),
407
+ grad_wkv.stride(2),
408
+ # Output state grad
409
+ grad_state,
410
+ grad_state.stride(0),
411
+ grad_state.stride(1),
412
+ grad_state.stride(3),
413
+ # W grad
414
+ gw,
415
+ gw.stride(0),
416
+ # U grad
417
+ gu,
418
+ gu.stride(0),
419
+ # K grad
420
+ gk,
421
+ gk.stride(0),
422
+ gk.stride(1),
423
+ gk.stride(2),
424
+ # V grad
425
+ gv,
426
+ gv.stride(0),
427
+ gv.stride(1),
428
+ gv.stride(2),
429
+ # State grad
430
+ gstate,
431
+ gstate.stride(0),
432
+ gstate.stride(1),
433
+ gstate.stride(3),
434
+ # Params
435
+ tsz,
436
+ chans,
437
+ BLOCK_SIZE_C=block_size_c,
438
+ )
439
+
440
+ return gw, gu, gk, gv, gstate
441
+
442
+
443
+ class FusedRecurrentRWKV4Function(Function):
444
+ @staticmethod
445
+ def forward(
446
+ ctx: FunctionCtx,
447
+ w: Tensor,
448
+ u: Tensor,
449
+ k: Tensor,
450
+ v: Tensor,
451
+ state: Tensor,
452
+ ) -> tuple[Tensor, Tensor]:
453
+ ctx.input_dtype = k.dtype
454
+
455
+ w = -torch.exp(w.float().contiguous())
456
+ if k.dtype == torch.float16:
457
+ u = u.float()
458
+ k = k.float()
459
+ v = v.float()
460
+ u = u.contiguous()
461
+ k = k.contiguous()
462
+ v = v.contiguous()
463
+ wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
464
+ ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
465
+ return wkv, state_out[:, :, -1:]
466
+
467
+ @staticmethod
468
+ @once_differentiable
469
+ def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
470
+ w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
471
+ gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
472
+ return gw, gu, gk, gv, gstate
473
+
474
+
475
+ def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
476
+ return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)
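For readers, the recurrence that `fused_recurrent_rwkv4_forward_kernel` implements can be restated per channel in plain PyTorch. This is a reading aid under my own naming, not part of the module; `w` here is the already-negated decay, i.e. `-exp(w_raw)`, exactly as prepared in `FusedRecurrentRWKV4Function.forward`.

```python
import torch

def naive_rwkv4(w, u, k, v, state):
    # w, u: [C]; k, v: [B, T, C]; state: [B, 3, 1, C] holding (alpha, beta, eps).
    alpha, beta, eps = state[:, 0, 0], state[:, 1, 0], state[:, 2, 0]
    wkv = torch.empty_like(v)
    for t in range(k.shape[1]):
        kt, vt = k[:, t], v[:, t]
        # Numerically stable wkv via a shared max shift (tau), as in the kernel.
        ukt = u + kt
        tau = torch.maximum(ukt, eps)
        e1, e2 = torch.exp(eps - tau), torch.exp(ukt - tau)
        wkv[:, t] = (e1 * alpha + e2 * vt) / (e1 * beta + e2)
        # State update with a fresh shift eps' = max(w + eps, k_t).
        w_eps = w + eps
        eps = torch.maximum(w_eps, kt)
        e1, e2 = torch.exp(w_eps - eps), torch.exp(kt - eps)
        alpha = e1 * alpha + e2 * vt
        beta = e1 * beta + e2
    return wkv, torch.stack((alpha, beta, eps), 1).unsqueeze(2)
```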
fla/ops/rwkv6/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_rwkv6
4
+ from .fused_recurrent import fused_recurrent_rwkv6
5
+
6
+ __all__ = [
7
+ 'chunk_rwkv6',
8
+ 'fused_recurrent_rwkv6'
9
+ ]