drbh committed · Commit b833fce · 1 Parent(s): 79e0916

fix: update library for all ops

build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py CHANGED
@@ -1,25 +1,45 @@
-from typing import Optional
-
+from typing import Optional, List
 import torch
-
 from ._ops import ops
 
+
 def mha_fwd(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    out: torch.Tensor,
-    alibi_slopes: torch.Tensor,
-    p_dropout: float,
-    softmax_scale: float,
-    is_causal: bool,
-    window_size_left: int,
-    window_size_right: int,
-    softcap: float,
-    return_softmax: bool,
-    gen: Optional[torch.Generator],
-) -> torch.Tensor:
-    ops.mha_fwd(
+    out: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    p_dropout: float = 0.0,
+    softmax_scale: float = 1.0,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    return_softmax: bool = False,
+    gen: Optional[torch.Generator] = None,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention.
+
+    Args:
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Optional output tensor, same shape as q
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        return_softmax: Whether to return softmax weights
+        gen: Optional random number generator
+
+    Returns:
+        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+    """
+    return ops.mha_fwd(
         q,
         k,
         v,
@@ -34,4 +54,311 @@ def mha_fwd(
         return_softmax,
         gen,
     )
-    return out
+
+
+def mha_varlen_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    max_seqlen_q: int = 0,
+    max_seqlen_k: int = 0,
+    p_dropout: float = 0.0,
+    softmax_scale: float = 1.0,
+    zero_tensors: bool = False,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    return_softmax: bool = False,
+    gen: Optional[torch.Generator] = None,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention with variable sequence lengths.
+
+    Args:
+        q: Query tensor of shape [total_q, num_heads, head_size]
+        k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+        out: Optional output tensor of shape [total_q, num_heads, head_size]
+        seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+        leftpad_k: Optional left padding for keys of shape [batch_size]
+        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        max_seqlen_q: Maximum sequence length for queries
+        max_seqlen_k: Maximum sequence length for keys
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        zero_tensors: Whether to zero tensors before computation
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        return_softmax: Whether to return softmax weights
+        gen: Optional random number generator
+
+    Returns:
+        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+    """
+    return ops.mha_varlen_fwd(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_k,
+        leftpad_k,
+        block_table,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+
+
+def mha_bwd(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor] = None,
+    dk: Optional[torch.Tensor] = None,
+    dv: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    p_dropout: float = 0.0,
+    softmax_scale: float = 1.0,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    gen: Optional[torch.Generator] = None,
+    rng_state: Optional[torch.Tensor] = None,
+) -> List[torch.Tensor]:
+    """
+    Backward pass for multi-head attention.
+
+    Args:
+        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+        dq: Optional gradient tensor for queries, same shape as q
+        dk: Optional gradient tensor for keys, same shape as k
+        dv: Optional gradient tensor for values, same shape as v
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        deterministic: Whether to use deterministic algorithms
+        gen: Optional random number generator
+        rng_state: Optional RNG state from forward pass
+
+    Returns:
+        List of tensors: [dq, dk, dv]
+    """
+    return ops.mha_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        deterministic,
+        gen,
+        rng_state,
+    )
+
+
+def mha_varlen_bwd(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    dq: Optional[torch.Tensor] = None,
+    dk: Optional[torch.Tensor] = None,
+    dv: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    max_seqlen_q: int = 0,
+    max_seqlen_k: int = 0,
+    p_dropout: float = 0.0,
+    softmax_scale: float = 1.0,
+    zero_tensors: bool = False,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    gen: Optional[torch.Generator] = None,
+    rng_state: Optional[torch.Tensor] = None,
+) -> List[torch.Tensor]:
+    """
+    Backward pass for multi-head attention with variable sequence lengths.
+
+    Args:
+        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+        dq: Optional gradient tensor for queries, same shape as q
+        dk: Optional gradient tensor for keys, same shape as k
+        dv: Optional gradient tensor for values, same shape as v
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        max_seqlen_q: Maximum sequence length for queries
+        max_seqlen_k: Maximum sequence length for keys
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        zero_tensors: Whether to zero tensors before computation
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        deterministic: Whether to use deterministic algorithms
+        gen: Optional random number generator
+        rng_state: Optional RNG state from forward pass
+
+    Returns:
+        List of tensors: [dq, dk, dv]
+    """
+    return ops.mha_varlen_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        deterministic,
+        gen,
+        rng_state,
+    )
+
+
+def mha_fwd_kvcache(
+    q: torch.Tensor,
+    kcache: torch.Tensor,
+    vcache: torch.Tensor,
+    k: Optional[torch.Tensor] = None,
+    v: Optional[torch.Tensor] = None,
+    seqlens_k: Optional[torch.Tensor] = None,
+    rotary_cos: Optional[torch.Tensor] = None,
+    rotary_sin: Optional[torch.Tensor] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    softmax_scale: float = 1.0,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    is_rotary_interleaved: bool = False,
+    num_splits: int = 1,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention with KV cache.
+
+    Args:
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+        v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+        seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+        rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+        rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+        cache_batch_idx: Optional indices to index into the KV cache
+        leftpad_k: Optional left padding for keys of shape [batch_size]
+        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        out: Optional output tensor, same shape as q
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        is_rotary_interleaved: Whether rotary embeddings are interleaved
+        num_splits: Number of splits for computation
+
+    Returns:
+        List of tensors: [output, softmax_lse]
+    """
+    return ops.mha_fwd_kvcache(
+        q,
+        kcache,
+        vcache,
+        k,
+        v,
+        seqlens_k,
+        rotary_cos,
+        rotary_sin,
+        cache_batch_idx,
+        leftpad_k,
+        block_table,
+        alibi_slopes,
+        out,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        is_rotary_interleaved,
+        num_splits,
+    )
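
For reference, a minimal usage sketch of the relaxed mha_fwd signature introduced here (illustrative only, not part of the commit). It assumes a CUDA build of this flash_attn package and fp16 inputs on the GPU; shapes and the returned list follow the docstrings above. Note that softmax_scale now defaults to 1.0, so the usual 1/sqrt(head_size) scale still has to be passed explicitly.

import math

import torch

import flash_attn

# Hypothetical sizes for illustration; shapes follow the mha_fwd docstring.
batch, seqlen, num_heads, head_size = 2, 128, 8, 64
q = torch.randn(batch, seqlen, num_heads, head_size, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, num_heads, head_size, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, num_heads, head_size, device="cuda", dtype=torch.float16)

# All optional arguments can now be left at their defaults; softmax_scale is
# passed explicitly because the new default of 1.0 is rarely what you want.
result = flash_attn.mha_fwd(
    q, k, v,
    softmax_scale=1.0 / math.sqrt(head_size),
    is_causal=True,
)
out, softmax_lse = result[0], result[1]  # [output, softmax_lse, ...] per the docstring
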
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/__init__.py CHANGED
(same diff as build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py above)
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/__init__.py CHANGED
(same diff as build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py above)
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/__init__.py CHANGED
@@ -1,25 +1,45 @@
1
- from typing import Optional
2
-
3
  import torch
4
-
5
  from ._ops import ops
6
 
 
7
  def mha_fwd(
8
  q: torch.Tensor,
9
  k: torch.Tensor,
10
  v: torch.Tensor,
11
- out: torch.Tensor,
12
- alibi_slopes: torch.Tensor,
13
- p_dropout: float,
14
- softmax_scale: float,
15
- is_causal: bool,
16
- window_size_left: int,
17
- window_size_right: int,
18
- softcap: float,
19
- return_softmax: bool,
20
- gen: Optional[torch.Generator],
21
- ) -> torch.Tensor:
22
- ops.mha_fwd(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  q,
24
  k,
25
  v,
@@ -34,4 +54,311 @@ def mha_fwd(
34
  return_softmax,
35
  gen,
36
  )
37
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
 
2
  import torch
 
3
  from ._ops import ops
4
 
5
+
6
  def mha_fwd(
7
  q: torch.Tensor,
8
  k: torch.Tensor,
9
  v: torch.Tensor,
10
+ out: Optional[torch.Tensor] = None,
11
+ alibi_slopes: Optional[torch.Tensor] = None,
12
+ p_dropout: float = 0.0,
13
+ softmax_scale: float = 1.0,
14
+ is_causal: bool = False,
15
+ window_size_left: int = -1,
16
+ window_size_right: int = -1,
17
+ softcap: float = 0.0,
18
+ return_softmax: bool = False,
19
+ gen: Optional[torch.Generator] = None,
20
+ ) -> List[torch.Tensor]:
21
+ """
22
+ Forward pass for multi-head attention.
23
+
24
+ Args:
25
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
26
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
27
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
28
+ out: Optional output tensor, same shape as q
29
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
30
+ p_dropout: Dropout probability
31
+ softmax_scale: Scale factor for softmax
32
+ is_causal: Whether to use causal attention
33
+ window_size_left: Window size for left context (-1 for unlimited)
34
+ window_size_right: Window size for right context (-1 for unlimited)
35
+ softcap: Soft cap for attention weights
36
+ return_softmax: Whether to return softmax weights
37
+ gen: Optional random number generator
38
+
39
+ Returns:
40
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
41
+ """
42
+ return ops.mha_fwd(
43
  q,
44
  k,
45
  v,
 
54
  return_softmax,
55
  gen,
56
  )
57
+
58
+
59
+ def mha_varlen_fwd(
60
+ q: torch.Tensor,
61
+ k: torch.Tensor,
62
+ v: torch.Tensor,
63
+ cu_seqlens_q: torch.Tensor,
64
+ cu_seqlens_k: torch.Tensor,
65
+ out: Optional[torch.Tensor] = None,
66
+ seqused_k: Optional[torch.Tensor] = None,
67
+ leftpad_k: Optional[torch.Tensor] = None,
68
+ block_table: Optional[torch.Tensor] = None,
69
+ alibi_slopes: Optional[torch.Tensor] = None,
70
+ max_seqlen_q: int = 0,
71
+ max_seqlen_k: int = 0,
72
+ p_dropout: float = 0.0,
73
+ softmax_scale: float = 1.0,
74
+ zero_tensors: bool = False,
75
+ is_causal: bool = False,
76
+ window_size_left: int = -1,
77
+ window_size_right: int = -1,
78
+ softcap: float = 0.0,
79
+ return_softmax: bool = False,
80
+ gen: Optional[torch.Generator] = None,
81
+ ) -> List[torch.Tensor]:
82
+ """
83
+ Forward pass for multi-head attention with variable sequence lengths.
84
+
85
+ Args:
86
+ q: Query tensor of shape [total_q, num_heads, head_size]
87
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
88
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
89
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
90
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
91
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
92
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
93
+ leftpad_k: Optional left padding for keys of shape [batch_size]
94
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
95
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
96
+ max_seqlen_q: Maximum sequence length for queries
97
+ max_seqlen_k: Maximum sequence length for keys
98
+ p_dropout: Dropout probability
99
+ softmax_scale: Scale factor for softmax
100
+ zero_tensors: Whether to zero tensors before computation
101
+ is_causal: Whether to use causal attention
102
+ window_size_left: Window size for left context (-1 for unlimited)
103
+ window_size_right: Window size for right context (-1 for unlimited)
104
+ softcap: Soft cap for attention weights
105
+ return_softmax: Whether to return softmax weights
106
+ gen: Optional random number generator
107
+
108
+ Returns:
109
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
110
+ """
111
+ return ops.mha_varlen_fwd(
112
+ q,
113
+ k,
114
+ v,
115
+ out,
116
+ cu_seqlens_q,
117
+ cu_seqlens_k,
118
+ seqused_k,
119
+ leftpad_k,
120
+ block_table,
121
+ alibi_slopes,
122
+ max_seqlen_q,
123
+ max_seqlen_k,
124
+ p_dropout,
125
+ softmax_scale,
126
+ zero_tensors,
127
+ is_causal,
128
+ window_size_left,
129
+ window_size_right,
130
+ softcap,
131
+ return_softmax,
132
+ gen,
133
+ )
134
+
135
+
136
+ def mha_bwd(
137
+ dout: torch.Tensor,
138
+ q: torch.Tensor,
139
+ k: torch.Tensor,
140
+ v: torch.Tensor,
141
+ out: torch.Tensor,
142
+ softmax_lse: torch.Tensor,
143
+ dq: Optional[torch.Tensor] = None,
144
+ dk: Optional[torch.Tensor] = None,
145
+ dv: Optional[torch.Tensor] = None,
146
+ alibi_slopes: Optional[torch.Tensor] = None,
147
+ p_dropout: float = 0.0,
148
+ softmax_scale: float = 1.0,
149
+ is_causal: bool = False,
150
+ window_size_left: int = -1,
151
+ window_size_right: int = -1,
152
+ softcap: float = 0.0,
153
+ deterministic: bool = False,
154
+ gen: Optional[torch.Generator] = None,
155
+ rng_state: Optional[torch.Tensor] = None,
156
+ ) -> List[torch.Tensor]:
157
+ """
158
+ Backward pass for multi-head attention.
159
+
160
+ Args:
161
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
162
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
163
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
164
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
165
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
166
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
167
+ dq: Optional gradient tensor for queries, same shape as q
168
+ dk: Optional gradient tensor for keys, same shape as k
169
+ dv: Optional gradient tensor for values, same shape as v
170
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
171
+ p_dropout: Dropout probability
172
+ softmax_scale: Scale factor for softmax
173
+ is_causal: Whether to use causal attention
174
+ window_size_left: Window size for left context (-1 for unlimited)
175
+ window_size_right: Window size for right context (-1 for unlimited)
176
+ softcap: Soft cap for attention weights
177
+ deterministic: Whether to use deterministic algorithms
178
+ gen: Optional random number generator
179
+ rng_state: Optional RNG state from forward pass
180
+
181
+ Returns:
182
+ List of tensors: [dq, dk, dv]
183
+ """
184
+ return ops.mha_bwd(
185
+ dout,
186
+ q,
187
+ k,
188
+ v,
189
+ out,
190
+ softmax_lse,
191
+ dq,
192
+ dk,
193
+ dv,
194
+ alibi_slopes,
195
+ p_dropout,
196
+ softmax_scale,
197
+ is_causal,
198
+ window_size_left,
199
+ window_size_right,
200
+ softcap,
201
+ deterministic,
202
+ gen,
203
+ rng_state,
204
+ )
205
+
206
+
207
+ def mha_varlen_bwd(
208
+ dout: torch.Tensor,
209
+ q: torch.Tensor,
210
+ k: torch.Tensor,
211
+ v: torch.Tensor,
212
+ out: torch.Tensor,
213
+ softmax_lse: torch.Tensor,
214
+ cu_seqlens_q: torch.Tensor,
215
+ cu_seqlens_k: torch.Tensor,
216
+ dq: Optional[torch.Tensor] = None,
217
+ dk: Optional[torch.Tensor] = None,
218
+ dv: Optional[torch.Tensor] = None,
219
+ alibi_slopes: Optional[torch.Tensor] = None,
220
+ max_seqlen_q: int = 0,
221
+ max_seqlen_k: int = 0,
222
+ p_dropout: float = 0.0,
223
+ softmax_scale: float = 1.0,
224
+ zero_tensors: bool = False,
225
+ is_causal: bool = False,
226
+ window_size_left: int = -1,
227
+ window_size_right: int = -1,
228
+ softcap: float = 0.0,
229
+ deterministic: bool = False,
230
+ gen: Optional[torch.Generator] = None,
231
+ rng_state: Optional[torch.Tensor] = None,
232
+ ) -> List[torch.Tensor]:
233
+ """
234
+ Backward pass for multi-head attention with variable sequence lengths.
235
+
236
+ Args:
237
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
238
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
239
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
240
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
241
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
242
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
243
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
244
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
245
+ dq: Optional gradient tensor for queries, same shape as q
246
+ dk: Optional gradient tensor for keys, same shape as k
247
+ dv: Optional gradient tensor for values, same shape as v
248
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
249
+ max_seqlen_q: Maximum sequence length for queries
+ max_seqlen_k: Maximum sequence length for keys
+ p_dropout: Dropout probability
+ softmax_scale: Scale factor for softmax
+ zero_tensors: Whether to zero tensors before computation
+ is_causal: Whether to use causal attention
+ window_size_left: Window size for left context (-1 for unlimited)
+ window_size_right: Window size for right context (-1 for unlimited)
+ softcap: Soft cap for attention weights
+ deterministic: Whether to use deterministic algorithms
+ gen: Optional random number generator
+ rng_state: Optional RNG state from forward pass
+
+ Returns:
+ List of tensors: [dq, dk, dv]
+ """
+ return ops.mha_varlen_bwd(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ alibi_slopes,
+ max_seqlen_q,
+ max_seqlen_k,
+ p_dropout,
+ softmax_scale,
+ zero_tensors,
+ is_causal,
+ window_size_left,
+ window_size_right,
+ softcap,
+ deterministic,
+ gen,
+ rng_state,
+ )
+
+
+ def mha_fwd_kvcache(
+ q: torch.Tensor,
+ kcache: torch.Tensor,
+ vcache: torch.Tensor,
+ k: Optional[torch.Tensor] = None,
+ v: Optional[torch.Tensor] = None,
+ seqlens_k: Optional[torch.Tensor] = None,
+ rotary_cos: Optional[torch.Tensor] = None,
+ rotary_sin: Optional[torch.Tensor] = None,
+ cache_batch_idx: Optional[torch.Tensor] = None,
+ leftpad_k: Optional[torch.Tensor] = None,
+ block_table: Optional[torch.Tensor] = None,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ out: Optional[torch.Tensor] = None,
+ softmax_scale: float = 1.0,
+ is_causal: bool = False,
+ window_size_left: int = -1,
+ window_size_right: int = -1,
+ softcap: float = 0.0,
+ is_rotary_interleaved: bool = False,
+ num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+ """
+ Forward pass for multi-head attention with KV cache.
+
+ Args:
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+ cache_batch_idx: Optional indices to index into the KV cache
+ leftpad_k: Optional left padding for keys of shape [batch_size]
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+ out: Optional output tensor, same shape as q
+ softmax_scale: Scale factor for softmax
+ is_causal: Whether to use causal attention
+ window_size_left: Window size for left context (-1 for unlimited)
+ window_size_right: Window size for right context (-1 for unlimited)
+ softcap: Soft cap for attention weights
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
+ num_splits: Number of splits for computation
+
+ Returns:
+ List of tensors: [output, softmax_lse]
+ """
+ return ops.mha_fwd_kvcache(
+ q,
+ kcache,
+ vcache,
+ k,
+ v,
+ seqlens_k,
+ rotary_cos,
+ rotary_sin,
+ cache_batch_idx,
+ leftpad_k,
+ block_table,
+ alibi_slopes,
+ out,
+ softmax_scale,
+ is_causal,
+ window_size_left,
+ window_size_right,
+ softcap,
+ is_rotary_interleaved,
+ num_splits,
+ )
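To make the new keyword-default signatures concrete, here is a minimal usage sketch (not part of the diff). It assumes a CUDA build of this kernel, fp16 inputs, and illustrative shapes; note that softmax_scale defaults to 1.0 in these wrappers, so the usual 1/sqrt(head_size) scale has to be passed explicitly, and the outputs are unpacked according to the return lists documented above.

    import math
    import torch
    from flash_attn import mha_fwd, mha_fwd_kvcache

    batch, seqlen, heads, head_dim = 2, 128, 8, 64
    scale = 1.0 / math.sqrt(head_dim)
    q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Prefill: causal self-attention over the whole sequence.
    out, softmax_lse = mha_fwd(q, k, v, softmax_scale=scale, is_causal=True)[:2]

    # Decode step: one new query token attending to a pre-filled KV cache
    # (cache length, seqlens_k dtype, and shapes are illustrative assumptions).
    kcache = torch.zeros(batch, 1024, heads, head_dim, device="cuda", dtype=torch.float16)
    vcache = torch.zeros_like(kcache)
    kcache[:, :seqlen] = k
    vcache[:, :seqlen] = v
    q_step = torch.randn(batch, 1, heads, head_dim, device="cuda", dtype=torch.float16)
    seqlens_k = torch.full((batch,), seqlen, device="cuda", dtype=torch.int32)
    step_out, step_lse = mha_fwd_kvcache(
        q_step, kcache, vcache, seqlens_k=seqlens_k, softmax_scale=scale, is_causal=True
    )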
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/__init__.py CHANGED
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py CHANGED
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py CHANGED
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py CHANGED
@@ -1,25 +1,45 @@
1
- from typing import Optional
2
-
3
  import torch
4
-
5
  from ._ops import ops
6
 
 
7
  def mha_fwd(
8
  q: torch.Tensor,
9
  k: torch.Tensor,
10
  v: torch.Tensor,
11
- out: torch.Tensor,
12
- alibi_slopes: torch.Tensor,
13
- p_dropout: float,
14
- softmax_scale: float,
15
- is_causal: bool,
16
- window_size_left: int,
17
- window_size_right: int,
18
- softcap: float,
19
- return_softmax: bool,
20
- gen: Optional[torch.Generator],
21
- ) -> torch.Tensor:
22
- ops.mha_fwd(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  q,
24
  k,
25
  v,
@@ -34,4 +54,311 @@ def mha_fwd(
34
  return_softmax,
35
  gen,
36
  )
37
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
 
2
  import torch
 
3
  from ._ops import ops
4
 
5
+
6
  def mha_fwd(
7
  q: torch.Tensor,
8
  k: torch.Tensor,
9
  v: torch.Tensor,
10
+ out: Optional[torch.Tensor] = None,
11
+ alibi_slopes: Optional[torch.Tensor] = None,
12
+ p_dropout: float = 0.0,
13
+ softmax_scale: float = 1.0,
14
+ is_causal: bool = False,
15
+ window_size_left: int = -1,
16
+ window_size_right: int = -1,
17
+ softcap: float = 0.0,
18
+ return_softmax: bool = False,
19
+ gen: Optional[torch.Generator] = None,
20
+ ) -> List[torch.Tensor]:
21
+ """
22
+ Forward pass for multi-head attention.
23
+
24
+ Args:
25
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
26
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
27
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
28
+ out: Optional output tensor, same shape as q
29
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
30
+ p_dropout: Dropout probability
31
+ softmax_scale: Scale factor for softmax
32
+ is_causal: Whether to use causal attention
33
+ window_size_left: Window size for left context (-1 for unlimited)
34
+ window_size_right: Window size for right context (-1 for unlimited)
35
+ softcap: Soft cap for attention weights
36
+ return_softmax: Whether to return softmax weights
37
+ gen: Optional random number generator
38
+
39
+ Returns:
40
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
41
+ """
42
+ return ops.mha_fwd(
43
  q,
44
  k,
45
  v,
 
54
  return_softmax,
55
  gen,
56
  )
57
+
58
+
59
+ def mha_varlen_fwd(
60
+ q: torch.Tensor,
61
+ k: torch.Tensor,
62
+ v: torch.Tensor,
63
+ cu_seqlens_q: torch.Tensor,
64
+ cu_seqlens_k: torch.Tensor,
65
+ out: Optional[torch.Tensor] = None,
66
+ seqused_k: Optional[torch.Tensor] = None,
67
+ leftpad_k: Optional[torch.Tensor] = None,
68
+ block_table: Optional[torch.Tensor] = None,
69
+ alibi_slopes: Optional[torch.Tensor] = None,
70
+ max_seqlen_q: int = 0,
71
+ max_seqlen_k: int = 0,
72
+ p_dropout: float = 0.0,
73
+ softmax_scale: float = 1.0,
74
+ zero_tensors: bool = False,
75
+ is_causal: bool = False,
76
+ window_size_left: int = -1,
77
+ window_size_right: int = -1,
78
+ softcap: float = 0.0,
79
+ return_softmax: bool = False,
80
+ gen: Optional[torch.Generator] = None,
81
+ ) -> List[torch.Tensor]:
82
+ """
83
+ Forward pass for multi-head attention with variable sequence lengths.
84
+
85
+ Args:
86
+ q: Query tensor of shape [total_q, num_heads, head_size]
87
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
88
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
89
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
90
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
91
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
92
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
93
+ leftpad_k: Optional left padding for keys of shape [batch_size]
94
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
95
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
96
+ max_seqlen_q: Maximum sequence length for queries
97
+ max_seqlen_k: Maximum sequence length for keys
98
+ p_dropout: Dropout probability
99
+ softmax_scale: Scale factor for softmax
100
+ zero_tensors: Whether to zero tensors before computation
101
+ is_causal: Whether to use causal attention
102
+ window_size_left: Window size for left context (-1 for unlimited)
103
+ window_size_right: Window size for right context (-1 for unlimited)
104
+ softcap: Soft cap for attention weights
105
+ return_softmax: Whether to return softmax weights
106
+ gen: Optional random number generator
107
+
108
+ Returns:
109
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
110
+ """
111
+ return ops.mha_varlen_fwd(
112
+ q,
113
+ k,
114
+ v,
115
+ out,
116
+ cu_seqlens_q,
117
+ cu_seqlens_k,
118
+ seqused_k,
119
+ leftpad_k,
120
+ block_table,
121
+ alibi_slopes,
122
+ max_seqlen_q,
123
+ max_seqlen_k,
124
+ p_dropout,
125
+ softmax_scale,
126
+ zero_tensors,
127
+ is_causal,
128
+ window_size_left,
129
+ window_size_right,
130
+ softcap,
131
+ return_softmax,
132
+ gen,
133
+ )
134
+
135
+
136
+ def mha_bwd(
137
+ dout: torch.Tensor,
138
+ q: torch.Tensor,
139
+ k: torch.Tensor,
140
+ v: torch.Tensor,
141
+ out: torch.Tensor,
142
+ softmax_lse: torch.Tensor,
143
+ dq: Optional[torch.Tensor] = None,
144
+ dk: Optional[torch.Tensor] = None,
145
+ dv: Optional[torch.Tensor] = None,
146
+ alibi_slopes: Optional[torch.Tensor] = None,
147
+ p_dropout: float = 0.0,
148
+ softmax_scale: float = 1.0,
149
+ is_causal: bool = False,
150
+ window_size_left: int = -1,
151
+ window_size_right: int = -1,
152
+ softcap: float = 0.0,
153
+ deterministic: bool = False,
154
+ gen: Optional[torch.Generator] = None,
155
+ rng_state: Optional[torch.Tensor] = None,
156
+ ) -> List[torch.Tensor]:
157
+ """
158
+ Backward pass for multi-head attention.
159
+
160
+ Args:
161
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
162
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
163
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
164
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
165
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
166
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
167
+ dq: Optional gradient tensor for queries, same shape as q
168
+ dk: Optional gradient tensor for keys, same shape as k
169
+ dv: Optional gradient tensor for values, same shape as v
170
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
171
+ p_dropout: Dropout probability
172
+ softmax_scale: Scale factor for softmax
173
+ is_causal: Whether to use causal attention
174
+ window_size_left: Window size for left context (-1 for unlimited)
175
+ window_size_right: Window size for right context (-1 for unlimited)
176
+ softcap: Soft cap for attention weights
177
+ deterministic: Whether to use deterministic algorithms
178
+ gen: Optional random number generator
179
+ rng_state: Optional RNG state from forward pass
180
+
181
+ Returns:
182
+ List of tensors: [dq, dk, dv]
183
+ """
184
+ return ops.mha_bwd(
185
+ dout,
186
+ q,
187
+ k,
188
+ v,
189
+ out,
190
+ softmax_lse,
191
+ dq,
192
+ dk,
193
+ dv,
194
+ alibi_slopes,
195
+ p_dropout,
196
+ softmax_scale,
197
+ is_causal,
198
+ window_size_left,
199
+ window_size_right,
200
+ softcap,
201
+ deterministic,
202
+ gen,
203
+ rng_state,
204
+ )
205
+
206
+
207
+ def mha_varlen_bwd(
208
+ dout: torch.Tensor,
209
+ q: torch.Tensor,
210
+ k: torch.Tensor,
211
+ v: torch.Tensor,
212
+ out: torch.Tensor,
213
+ softmax_lse: torch.Tensor,
214
+ cu_seqlens_q: torch.Tensor,
215
+ cu_seqlens_k: torch.Tensor,
216
+ dq: Optional[torch.Tensor] = None,
217
+ dk: Optional[torch.Tensor] = None,
218
+ dv: Optional[torch.Tensor] = None,
219
+ alibi_slopes: Optional[torch.Tensor] = None,
220
+ max_seqlen_q: int = 0,
221
+ max_seqlen_k: int = 0,
222
+ p_dropout: float = 0.0,
223
+ softmax_scale: float = 1.0,
224
+ zero_tensors: bool = False,
225
+ is_causal: bool = False,
226
+ window_size_left: int = -1,
227
+ window_size_right: int = -1,
228
+ softcap: float = 0.0,
229
+ deterministic: bool = False,
230
+ gen: Optional[torch.Generator] = None,
231
+ rng_state: Optional[torch.Tensor] = None,
232
+ ) -> List[torch.Tensor]:
233
+ """
234
+ Backward pass for multi-head attention with variable sequence lengths.
235
+
236
+ Args:
237
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
238
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
239
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
240
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
241
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
242
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
243
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
244
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+ dq: Optional gradient tensor for queries, same shape as q
+ dk: Optional gradient tensor for keys, same shape as k
+ dv: Optional gradient tensor for values, same shape as v
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+ max_seqlen_q: Maximum sequence length for queries
+ max_seqlen_k: Maximum sequence length for keys
+ p_dropout: Dropout probability
+ softmax_scale: Scale factor for softmax
+ zero_tensors: Whether to zero tensors before computation
+ is_causal: Whether to use causal attention
+ window_size_left: Window size for left context (-1 for unlimited)
+ window_size_right: Window size for right context (-1 for unlimited)
+ softcap: Soft cap for attention weights
+ deterministic: Whether to use deterministic algorithms
+ gen: Optional random number generator
+ rng_state: Optional RNG state from forward pass
+
+ Returns:
+ List of tensors: [dq, dk, dv]
+ """
+ return ops.mha_varlen_bwd(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ alibi_slopes,
+ max_seqlen_q,
+ max_seqlen_k,
+ p_dropout,
+ softmax_scale,
+ zero_tensors,
+ is_causal,
+ window_size_left,
+ window_size_right,
+ softcap,
+ deterministic,
+ gen,
+ rng_state,
+ )
+
+
+ def mha_fwd_kvcache(
+ q: torch.Tensor,
+ kcache: torch.Tensor,
+ vcache: torch.Tensor,
+ k: Optional[torch.Tensor] = None,
+ v: Optional[torch.Tensor] = None,
+ seqlens_k: Optional[torch.Tensor] = None,
+ rotary_cos: Optional[torch.Tensor] = None,
+ rotary_sin: Optional[torch.Tensor] = None,
+ cache_batch_idx: Optional[torch.Tensor] = None,
+ leftpad_k: Optional[torch.Tensor] = None,
+ block_table: Optional[torch.Tensor] = None,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ out: Optional[torch.Tensor] = None,
+ softmax_scale: float = 1.0,
+ is_causal: bool = False,
+ window_size_left: int = -1,
+ window_size_right: int = -1,
+ softcap: float = 0.0,
+ is_rotary_interleaved: bool = False,
+ num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+ """
+ Forward pass for multi-head attention with KV cache.
+
+ Args:
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+ cache_batch_idx: Optional indices to index into the KV cache
+ leftpad_k: Optional left padding for keys of shape [batch_size]
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+ out: Optional output tensor, same shape as q
+ softmax_scale: Scale factor for softmax
+ is_causal: Whether to use causal attention
+ window_size_left: Window size for left context (-1 for unlimited)
+ window_size_right: Window size for right context (-1 for unlimited)
+ softcap: Soft cap for attention weights
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
+ num_splits: Number of splits for computation
+
+ Returns:
+ List of tensors: [output, softmax_lse]
+ """
+ return ops.mha_fwd_kvcache(
+ q,
+ kcache,
+ vcache,
+ k,
+ v,
+ seqlens_k,
+ rotary_cos,
+ rotary_sin,
+ cache_batch_idx,
+ leftpad_k,
+ block_table,
+ alibi_slopes,
+ out,
+ softmax_scale,
+ is_causal,
+ window_size_left,
+ window_size_right,
+ softcap,
+ is_rotary_interleaved,
+ num_splits,
+ )
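The wrappers above only forward to the underlying ops entry points, but the new keyword defaults mean they can be called without spelling out every positional argument. Below is a minimal usage sketch (not part of this commit): it assumes one of the built flash_attn wheels above is importable, a CUDA device, fp16 inputs, and illustrative shapes, and it passes softmax_scale explicitly because the wrapper default is 1.0 rather than the usual 1/sqrt(head_size).

# Hedged example: assumes the flash_attn package built above is installed
# and a CUDA device is available; shapes and dtypes are illustrative only.
import torch
from flash_attn import mha_fwd, mha_fwd_kvcache

batch, seqlen, heads, head_size = 2, 128, 8, 64
dev, dtype = "cuda", torch.float16
q = torch.randn(batch, seqlen, heads, head_size, device=dev, dtype=dtype)
k = torch.randn(batch, seqlen, heads, head_size, device=dev, dtype=dtype)
v = torch.randn(batch, seqlen, heads, head_size, device=dev, dtype=dtype)

# Full-sequence causal forward pass; returns [output, softmax_lse, ...].
out, softmax_lse = mha_fwd(q, k, v, softmax_scale=head_size ** -0.5, is_causal=True)[:2]

# Single decode step against a preallocated KV cache: the new k/v block is
# written into the cache at position seqlens_k and attention runs over the
# cached prefix plus the new token.
max_cache = 1024
kcache = torch.zeros(batch, max_cache, heads, head_size, device=dev, dtype=dtype)
vcache = torch.zeros_like(kcache)
seqlens_k = torch.full((batch,), seqlen, dtype=torch.int32, device=dev)
q1 = torch.randn(batch, 1, heads, head_size, device=dev, dtype=dtype)
k1 = torch.randn(batch, 1, heads, head_size, device=dev, dtype=dtype)
v1 = torch.randn(batch, 1, heads, head_size, device=dev, dtype=dtype)
out1, lse1 = mha_fwd_kvcache(
    q1, kcache, vcache, k=k1, v=v1, seqlens_k=seqlens_k,
    softmax_scale=head_size ** -0.5, is_causal=True,
)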
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py CHANGED
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py CHANGED