|
# Flash Attention |
|
Flash Attention is a fast, memory-efficient implementation of the attention mechanism, designed for large models and long sequences. This repository provides a kernel build of Flash Attention that is compliant with the Hugging Face `kernels` library.

The original implementation lives at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
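
For reference, Flash Attention computes exactly the same result as standard scaled dot-product attention, `softmax(QKᵀ / √d) V`; the speed and memory savings come from tiling the computation so the full `seq_len × seq_len` attention matrix is never materialized in GPU memory. Below is a minimal reference implementation for illustration only (the `naive_attention` helper is not part of this kernel), using the same `(batch, seq_len, heads, head_dim)` layout the kernel expects:

```python
import torch

def naive_attention(q, k, v, is_causal=False):
    # q, k, v: (batch, seq_len, heads, head_dim); returns the same shape.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if is_causal:
        causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    # This (seq_len x seq_len) matrix is exactly what Flash Attention avoids materializing.
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v)
```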
|
```python
# /// script
# dependencies = ["numpy", "torch", "kernels"]
# ///
import torch
from kernels import get_kernel

# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Show available functions
print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])

# 1. Standard attention
print("\n1. Standard attention:")
B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
print(f"Output: {out.shape}")

# 2. Variable length sequences
print("\n2. Variable length sequences:")
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
# For 3 sequences with lengths [3, 4, 3] for q and [4, 5, 3] for k
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
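# cu_seqlens_* are cumulative offsets into the packed token dimension: sequence i
# occupies rows cu[i]:cu[i+1], so the per-sequence lengths are the consecutive
# differences (cu_q -> [3, 4, 3], cu_k -> [4, 5, 3]), and max_seqlen_* is the largest of them.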
|
out_var = flash_attn.mha_varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,
    max_seqlen_k=5,
)[0]
print(f"Output: {out_var.shape}")

# 3. KV-cache for autoregressive generation
print("\n3. KV-cache:")
cache_len, new_len = 10, 2
kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
q_new = k_new = v_new = torch.randn(
    B, new_len, H, D, device=device, dtype=torch.float16
)
seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
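# Intended usage: the new queries attend over the cached K/V together with the
# newly supplied k_new/v_new; only new_len output positions are produced
# (see the output shape below).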
|
out_kv = flash_attn.mha_fwd_kvcache(
    q=q_new,
    kcache=kcache,
    vcache=vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens,
    is_causal=True,
)[0]
print(f"Output: {out_kv.shape}")
```
|
Expected output:
|
```txt
Fetching 3 files: 100%|█████████████████████████████████████████████| 3/3 [00:00<00:00, 16384.00it/s]
Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']

1. Standard attention:
Output: torch.Size([2, 5, 4, 8])

2. Variable length sequences:
Output: torch.Size([10, 4, 8])

3. KV-cache:
Output: torch.Size([2, 2, 4, 8])
```
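
As a quick numerical sanity check, the kernel's output can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. The sketch below is illustrative (the tolerances are arbitrary, not part of this repository); note that the kernel uses a `(batch, seq_len, heads, head_dim)` layout while SDPA expects `(batch, heads, seq_len, head_dim)`, hence the transposes.

```python
import torch
from kernels import get_kernel

flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

B, S, H, D = 2, 5, 4, 8
q, k, v = (torch.randn(B, S, H, D, device=device, dtype=torch.float16) for _ in range(3))

# Kernel output: the attention result is the first element of the returned tuple,
# as in the example above.
out_flash = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]

# PyTorch reference; transpose to and from SDPA's (B, H, S, D) layout.
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Should print True up to fp16 round-off (tolerances here are illustrative).
print(torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-2))
```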