drbh committed on
Commit d6cc1b0 · 1 Parent(s): b833fce

feat improve readme and library code

Files changed (2)
  1. README.md +80 -0
  2. torch-ext/flash_attn/__init__.py +343 -16
README.md ADDED
@@ -0,0 +1,80 @@
+ # Flash Attention
+
+ Flash Attention is a fast and memory-efficient implementation of the attention mechanism, designed to work with large models and long sequences. This is a Hugging Face-compliant kernel build of Flash Attention, loadable with the `kernels` library.
+
+ The original implementation is available at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
+
+ ```python
+ # /// script
+ # dependencies = ["numpy", "torch", "kernels"]
+ # ///
+ import torch
+ from kernels import get_kernel
+
+ # Setup
+ torch.manual_seed(42)
+ flash_attn = get_kernel("kernels-community/flash-attn")
+ device = torch.device("cuda")
+
+ # Show available functions
+ print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])
+
+ # 1. Standard attention
+ print("\n1. Standard attention:")
+ B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
+ q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
+ out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
+ print(f"Output: {out.shape}")
+
+ # 2. Variable length sequences
+ print("\n2. Variable length sequences:")
+ q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
+ k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
+ # For 3 sequences with lengths [3, 4, 3] for q and [4, 5, 3] for k
+ cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
+ cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
+ out_var = flash_attn.mha_varlen_fwd(
+     q=q_var,
+     k=k_var,
+     v=v_var,
+     cu_seqlens_q=cu_q,
+     cu_seqlens_k=cu_k,
+     max_seqlen_q=4,
+     max_seqlen_k=5,
+ )[0]
+ print(f"Output: {out_var.shape}")
+
+ # 3. KV-cache for autoregressive generation
+ print("\n3. KV-cache:")
+ cache_len, new_len = 10, 2
+ kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
+ q_new = k_new = v_new = torch.randn(
+     B, new_len, H, D, device=device, dtype=torch.float16
+ )
+ seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
+ out_kv = flash_attn.mha_fwd_kvcache(
+     q=q_new,
+     kcache=kcache,
+     vcache=vcache,
+     k=k_new,
+     v=v_new,
+     seqlens_k=seqlens,
+     is_causal=True,
+ )[0]
+ print(f"Output: {out_kv.shape}")
+ ```
+
+ Expected output:
+ ```txt
+ Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16384.00it/s]
+ Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']
+
+ 1. Standard attention:
+ Output: torch.Size([2, 5, 4, 8])
+
+ 2. Variable length sequences:
+ Output: torch.Size([10, 4, 8])
+
+ 3. KV-cache:
+ Output: torch.Size([2, 2, 4, 8])
+ ```
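
The README examples above leave `softmax_scale` at its default of `1.0` (see the updated signatures below), whereas standard scaled dot-product attention uses `1/sqrt(head_dim)`. The following sketch is not part of this commit: it assumes a recent PyTorch with the `scale=` argument on `F.scaled_dot_product_attention` and fp16-level tolerances, and shows how an explicit scale can be passed to `mha_fwd` and cross-checked against the PyTorch reference.

```python
# Cross-check sketch (not part of this commit): compare mha_fwd with an explicit
# softmax_scale against PyTorch's reference scaled dot-product attention.
import math

import torch
import torch.nn.functional as F
from kernels import get_kernel

flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

B, S, H, D = 2, 5, 4, 8
q = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
k = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)

# Standard attention scaling; the wrapper's default is softmax_scale=1.0.
scale = 1.0 / math.sqrt(D)
out = flash_attn.mha_fwd(q=q, k=k, v=v, softmax_scale=scale, is_causal=False)[0]

# Reference: F.scaled_dot_product_attention expects [B, H, S, D] layout,
# so transpose in and back out of the flash-attn [B, S, H, D] layout.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), scale=scale
).transpose(1, 2)

print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
```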
torch-ext/flash_attn/__init__.py CHANGED
@@ -1,25 +1,45 @@
- from typing import Optional
-
+ from typing import Optional, List
  import torch
-
  from ._ops import ops

+
  def mha_fwd(
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
-     out: torch.Tensor,
-     alibi_slopes: torch.Tensor,
-     p_dropout: float,
-     softmax_scale: float,
-     is_causal: bool,
-     window_size_left: int,
-     window_size_right: int,
-     softcap: float,
-     return_softmax: bool,
-     gen: Optional[torch.Generator],
- ) -> torch.Tensor:
-     ops.mha_fwd(
+     out: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Optional output tensor, same shape as q
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_fwd(
          q,
          k,
          v,
@@ -34,4 +54,311 @@ def mha_fwd
          return_softmax,
          gen,
      )
-     return out
+
+
+ def mha_varlen_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     seqused_k: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         out: Optional output tensor of shape [total_q, num_heads, head_size]
+         seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_varlen_fwd(
+         q,
+         k,
+         v,
+         out,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_k,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_varlen_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         dout: Gradient tensor of shape [total_q, num_heads, head_size]
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_varlen_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_fwd_kvcache(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache: torch.Tensor,
+     k: Optional[torch.Tensor] = None,
+     v: Optional[torch.Tensor] = None,
+     seqlens_k: Optional[torch.Tensor] = None,
+     rotary_cos: Optional[torch.Tensor] = None,
+     rotary_sin: Optional[torch.Tensor] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     out: Optional[torch.Tensor] = None,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     is_rotary_interleaved: bool = False,
+     num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with KV cache.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+         rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+         rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+         cache_batch_idx: Optional indices to index into the KV cache
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         out: Optional output tensor, same shape as q
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         is_rotary_interleaved: Whether rotary embeddings are interleaved
+         num_splits: Number of splits for computation
+
+     Returns:
+         List of tensors: [output, softmax_lse]
+     """
+     return ops.mha_fwd_kvcache(
+         q,
+         kcache,
+         vcache,
+         k,
+         v,
+         seqlens_k,
+         rotary_cos,
+         rotary_sin,
+         cache_batch_idx,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         out,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         is_rotary_interleaved,
+         num_splits,
+     )
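
The variable-length entry points take packed tensors plus cumulative sequence lengths rather than a padded batch. The sketch below is not part of this commit: the helper name `pack_padded` and the padded/length inputs are assumptions made for illustration, showing one way to build the `cu_seqlens_*` and `max_seqlen_*` arguments documented in `mha_varlen_fwd`.

```python
# Illustrative helper (not part of this commit): pack a padded [B, S, H, D] batch
# into the [total_tokens, H, D] layout plus cu_seqlens expected by mha_varlen_fwd.
import torch


def pack_padded(x: torch.Tensor, lengths: torch.Tensor):
    """x: [B, S, H, D]; lengths: [B] actual sequence lengths per batch element."""
    B, S, H, D = x.shape
    # Boolean mask of valid tokens, then gather them into one flat token dimension.
    mask = torch.arange(S, device=x.device)[None, :] < lengths[:, None]  # [B, S]
    packed = x[mask]  # [total_tokens, H, D]
    # Cumulative sequence lengths: [0, len_0, len_0+len_1, ...], int32 as the kernel expects.
    cu_seqlens = torch.zeros(B + 1, device=x.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0).to(torch.int32)
    return packed, cu_seqlens, int(lengths.max())


# Hypothetical wiring (assumes flash_attn = get_kernel("kernels-community/flash-attn")):
# q_packed, cu_q, max_q = pack_padded(q_padded, q_lens)
# k_packed, cu_k, max_k = pack_padded(k_padded, kv_lens)
# v_packed, _, _ = pack_padded(v_padded, kv_lens)
# out = flash_attn.mha_varlen_fwd(
#     q=q_packed, k=k_packed, v=v_packed,
#     cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
#     max_seqlen_q=max_q, max_seqlen_k=max_k,
# )[0]
```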